commit 62d082634e
parent d1c354eac7
Author: root
Date:   2025-07-29 08:40:51 +00:00

7 changed files with 71 additions and 68 deletions

View File

@@ -21,6 +21,7 @@ import torch
 from safetensors import safe_open
 from transformers import AutoTokenizer
+
 def get_args():
     parser = ArgumentParser()
@@ -39,6 +40,7 @@ def get_args():
     args = parser.parse_args()
     return args
+
 if __name__ == "__main__":
     args = get_args()
@@ -67,4 +69,3 @@ if __name__ == "__main__":
     hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"]
     torch.save(hf_tensors, args.output_path)
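The last hunk materializes a weight-tied lm_head at export time: checkpoints whose lm_head is tied to the input embeddings (as in Qwen-style models) carry no separate head tensor, so the script aliases the embedding matrix under the lm_head key before saving. A minimal sketch of that aliasing, with illustrative shapes not taken from the diff:

import torch

# Checkpoints with tied weights only carry the embedding matrix.
state = {"llm.model.model.embed_tokens.weight": torch.randn(1000, 64)}
# Alias, not copy: both keys now reference the same storage.
state["llm.model.lm_head.weight"] = state["llm.model.model.embed_tokens.weight"]
assert (state["llm.model.lm_head.weight"].data_ptr()
        == state["llm.model.model.embed_tokens.weight"].data_ptr())
torch.save(state, "model.pt")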

View File

@@ -105,6 +105,7 @@ def extract_speech_ids(speech_tokens_str):
             print(f"Unexpected token: {token_str}")
     return speech_ids
+
 def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
     """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
     speech_id_str = ""
@@ -182,14 +183,13 @@ def get_args():
     return args
 def data_collator(batch, tokenizer, s3_tokenizer):
     """Simplified data collator for batch_size=1 processing"""
     target_sample_rate = 16000  # CosyVoice2 uses 16kHz for prompt audio
     device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
     input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
     mels, prompt_audio_cosy2tokens_list = [], []
-    for i, item in enumerate(batch):
+    for item in batch:
         prompt_text, target_text = (
             item["prompt_text"],
             item["target_text"],
@@ -227,7 +227,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
     codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
     for i in range(len(codes)):
         prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
-    for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
+    for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
         prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
         # Create chat template for LLM generation
         chat = [
@@ -244,7 +244,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         )
         input_ids_list.append(input_ids.squeeze(0))
     # For batch_size=1, no need to pad
     if len(input_ids_list) == 1:
         input_ids = input_ids_list[0].unsqueeze(0)
@@ -256,7 +255,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
            for input_ids in input_ids_list
         ]
         input_ids = torch.stack(input_ids_list)
     ids = [item["id"] for item in batch]
     return {
@@ -287,7 +286,7 @@ def main():
     assert torch.cuda.is_available()
     world_size, local_rank, rank = init_distributed()
     device = torch.device(f"cuda:{local_rank}")
     # Load LLM model and tokenizer directly
     tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
     model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
@@ -329,7 +328,7 @@ def main():
     for batch in dataloader:
         with torch.no_grad():
             input_ids = batch["input_ids"].to(device)
             # Generate speech tokens using LLM
             outputs = model.generate(
                 input_ids,
@@ -339,31 +338,31 @@ def main():
                 temperature=args.temperature,
                 top_k=args.top_k,
             )
             # Process each sample in the batch
             for i in range(len(batch["ids"])):
                 # Extract generated tokens (excluding input)
                 input_length = input_ids[i].shape[0]
                 generated_ids = outputs[i][input_length:-1]  # Remove last token if needed
                 speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                 # Extract speech IDs from token strings like <|s_23456|>
                 speech_ids = extract_speech_ids(speech_tokens_str)
                 if len(speech_ids) == 0:
                     print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                     continue
                 # Convert to tensor for CosyVoice2
                 audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
                 if args.prompt_text is not None:
                     current_prompt_text = args.prompt_text
                     current_prompt_audio = prompt_speech_16k
                 else:
                     current_prompt_text = batch["prompt_text"][i]
                     current_prompt_audio = batch["prompt_audio_list"][i]
                 if current_prompt_audio is not None:
                     # Generate audio using CosyVoice2
                     audio_hat = audio_decode_cosyvoice2(
@@ -372,18 +371,17 @@ def main():
                         current_prompt_audio,
                         cosyvoice_codec,
                     )
                     # Convert to numpy and save
                     generated_wave = audio_hat.squeeze(0).cpu().numpy()
                     target_sample_rate = 24000
                     utt = batch["ids"][i]
                     sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
                     print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
                 else:
                     print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
             if rank == 0:
                 progress_bar.update(world_size * len(batch["ids"]))
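The sample loop round-trips speech tokens through two helpers: convert_cosy2_tokens_to_speech_id_str renders codec IDs into the LLM prompt, and extract_speech_ids recovers IDs from the generated text. A self-contained sketch of that round trip, with helper bodies inferred from their docstrings and log messages rather than copied from the diff:

import re
from typing import List

def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens) -> str:
    # Render CosyVoice2 codec IDs as '<|s_N|>' tokens for the LLM prompt.
    return "".join(f"<|s_{int(t)}|>" for t in cosy2_tokens)

def extract_speech_ids(speech_tokens_str: List[str]) -> List[int]:
    # Recover codec IDs from decoded tokens like '<|s_23456|>'.
    speech_ids = []
    for token_str in speech_tokens_str:
        m = re.fullmatch(r"<\|s_(\d+)\|>", token_str)
        if m:
            speech_ids.append(int(m.group(1)))
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids

ids = [23456, 17, 42]
assert extract_speech_ids([f"<|s_{i}|>" for i in ids]) == ids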

View File

@@ -23,8 +23,6 @@ import datasets
 from verl.utils.hdfs_io import copy, makedirs
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file")

View File

@@ -31,7 +31,6 @@ import torch
 sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
 def get_args():
     parser = ArgumentParser()
@@ -96,17 +95,20 @@ if __name__ == "__main__":
     # set the weight and bias of the new lm_head to 0
     new_lm_head.weight.data.zero_()
     new_lm_head.bias.data.zero_()
-    new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.weight
-    new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.bias
+    new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight
+    new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias
     llm.lm_head = new_lm_head
     input_embeddings = llm.get_input_embeddings()
     with torch.no_grad():
-        input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = speech_embedding.weight
-        input_embeddings.weight[original_tokenizer_vocab_size+cosyvoice2_token_size+3:original_tokenizer_vocab_size+cosyvoice2_token_size+3+2] = llm_embedding.weight
+        input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight
+        input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight
-    eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size, original_tokenizer_vocab_size + cosyvoice2_token_size + 1, original_tokenizer_vocab_size + cosyvoice2_token_size + 2]
+    eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 1,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 2,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 3]
     llm.generation_config.eos_token_id = eos_token_ids
     llm.generation_config.temperature = 1.0
     llm.generation_config.top_p = 0.8
@@ -121,4 +123,4 @@ if __name__ == "__main__":
     TEMPLATE = "{%- for message in messages %}{%- if message['role'] == 'user' %}{{- '<|sos|>' + message['content'] + '<|task_id|>' }}{%- elif message['role'] == 'assistant' %}{{- message['content']}}{%- endif %}{%- endfor %}"
     tokenizer.chat_template = TEMPLATE
     tokenizer.save_pretrained(args.save_path)
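The slice assignments imply a fixed token-id layout: the text vocabulary first, then the CosyVoice2 codec tokens, then a few control tokens, and the commit widens the EOS set from three IDs to four. A sketch of the arithmetic, assuming cosyvoice2_token_size = 6561 (the CosyVoice2 codebook size) and an illustrative text vocab size; only the offsets come from the code above:

original_tokenizer_vocab_size = 151665  # illustrative, not from the diff
cosyvoice2_token_size = 6561            # assumption: CosyVoice2 codebook size

def speech_token_id(n: int) -> int:
    # Row in the extended embedding/lm_head that backs '<|s_n|>'.
    return original_tokenizer_vocab_size + n

# The four EOS ids now registered on generation_config.
eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size + k for k in range(4)]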

View File

@@ -18,7 +18,10 @@ Reward calculation for CosyVoice2-0.5B.
 from __future__ import annotations
-import os, re, warnings, json, time, argparse
+import re
+import json
+import time
+import argparse
 from typing import List
 import numpy as np
@@ -31,6 +34,7 @@ REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer"
 def _parse_ids(token_str: str) -> List[int]:
     return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]
+
 def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
     """Send token IDs and ground-truth text to the Triton server and get reward."""
@@ -100,7 +104,6 @@ def compute_score(
     try:
         reward = _remote_reward(ids, ground_truth)
     except Exception as e:
-        warnings.warn(f"Remote reward server error: {e}; returning 0.0")
         reward = 0.0
     if debug_dump:
@@ -110,46 +113,46 @@ def compute_score(
     return reward
 # CLI quick test
 if __name__ == "__main__":
     import sys
     def get_args():
         """Parse command line arguments."""
         parser = argparse.ArgumentParser(
             description="Test TTS CER scoring with data from JSONL file",
             formatter_class=argparse.ArgumentDefaultsHelpFormatter
         )
         parser.add_argument(
             "--input", "-i",
             type=str,
             default="data/emilia_zh-cosy-tiny-test.jsonl",
             help="Path to input JSONL file"
         )
         parser.add_argument(
             "--max-samples", "-n",
             type=int,
             default=None,
             help="Maximum number of samples to process (default: all)"
         )
         parser.add_argument(
             "--no-interactive",
             action="store_true",
             help="Run in non-interactive mode (process all samples without prompts)"
         )
         parser.add_argument(
             "--debug",
             action="store_true",
             help="Enable debug mode"
         )
         return parser.parse_args()
     def load_jsonl(file_path: str):
         """Load data from jsonl file."""
         data = []
@@ -157,37 +160,37 @@ if __name__ == "__main__":
             for line in f:
                 data.append(json.loads(line.strip()))
         return data
     def code_to_solution_str(code_list: List[int]) -> str:
         """Convert code list to solution string format."""
         return ''.join([f"<|s_{code}|>" for code in code_list])
     # Parse command line arguments
     args = get_args()
     try:
         # Load data from jsonl file
         print(f"Loading data from: {args.input}")
         data_list = load_jsonl(args.input)
         print(f"Loaded {len(data_list)} samples")
         # Limit samples if specified
         if args.max_samples is not None:
             data_list = data_list[:args.max_samples]
             print(f"Processing first {len(data_list)} samples (limited by --max-samples)")
         # Process each sample
         begin_time = time.time()
         for i, sample in enumerate(data_list):
             print(f"\n--- Sample {i+1}/{len(data_list)} ---")
             print(f"Index: {sample.get('index', 'unknown')}")
             print(f"Text: {sample['text']}")
             # Extract required fields
             code_list = sample['code']
             ground_truth = sample['text']
             data_source = sample.get('index', f'sample_{i}')  # Use index as data_source
             # Convert code list to solution string
             solution_str = code_to_solution_str(code_list)
             print(f"Solution tokens: {len(code_list)} tokens")
@@ -195,7 +198,7 @@ if __name__ == "__main__":
                 print(f"Solution string: {solution_str}")
             else:
                 print(f"Solution string preview: {solution_str[:100]}..." if len(solution_str) > 100 else f"Solution string: {solution_str}")
             # Call compute_score function
             try:
                 score = compute_score(
@@ -208,7 +211,7 @@ if __name__ == "__main__":
                 print(f"Final Score: {score:.4f}")
             except Exception as e:
                 print(f"Error computing score: {e}")
             # Ask user if they want to continue (for interactive mode)
             if not args.no_interactive and i < len(data_list) - 1:
                 try:
@@ -218,7 +221,7 @@ if __name__ == "__main__":
                 except KeyboardInterrupt:
                     print("\nStopped by user")
                     break
         print(f"\nProcessed {min(i+1, len(data_list))} samples")
         end_time = time.time()
         print(f"Time taken: {end_time - begin_time} seconds")

View File

@@ -102,6 +102,7 @@ import string
 punctuation_all = punctuation + string.punctuation
 Pathlike = Union[str, Path]
+
 def remove_punctuation(text: str) -> str:
     for x in punctuation_all:
         if x == '\'':
@@ -109,6 +110,7 @@ def remove_punctuation(text: str) -> str:
             text = text.replace(x, '')
     return text
+
 def store_transcripts(
     filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
 ) -> None:
@@ -304,6 +306,7 @@ def write_error_stats(
         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
     return float(tot_err_rate)
+
 def get_args():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -533,7 +536,7 @@ def get_args():
         default=None,
         help="wav_base_name label",
     )
     # Dataset related arguments for loading labels when label file is not provided
     parser.add_argument(
         "--dataset-name",
@@ -541,14 +544,14 @@ def get_args():
         default="yuekai/seed_tts_cosy2",
         help="Huggingface dataset name for loading labels",
     )
     parser.add_argument(
         "--split-name",
         type=str,
         default="wenetspeech4tts",
         help="Dataset split name for loading labels",
     )
     return parser.parse_args()
@@ -590,7 +593,7 @@ def normalize_text_alimeeting(text: str) -> str:
     See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
     """
     import re
     text = text.replace('\u00A0', '')  # test_hard
     text = text.replace(" ", "")
     text = text.replace("<sil>", "")
     text = text.replace("<%>", "")
@@ -685,10 +688,10 @@ def main():
     print(
         f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
     )
     # Load labels either from file or from dataset
     labels_dict = {}
     if args.label:
         # Load labels from file (original functionality)
         print(f"Loading labels from file: {args.label}")
@@ -716,11 +719,11 @@ def main():
             split=args.split_name,
             trust_remote_code=True,
         )
         for item in dataset:
             audio_id = item["id"]
             labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
         print(f"Loaded {len(labels_dict)} labels from dataset")
     # Perform evaluation if labels are available
@@ -750,4 +753,4 @@ def main():
 if __name__ == "__main__":
     main()
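The replacements visible in normalize_text_alimeeting are enough to reproduce its behavior on simple inputs. A partial, standalone sketch (the full function in the repo does more, per the AliMeeting normalization script it cites):

def normalize_text_alimeeting(text: str) -> str:
    text = text.replace('\u00A0', '')  # non-breaking space, seen in test_hard
    text = text.replace(" ", "")       # CER is computed without spaces
    text = text.replace("<sil>", "")   # silence marker
    text = text.replace("<%>", "")     # filler marker
    return text

assert normalize_text_alimeeting("你 好<sil>世界<%>") == "你好世界"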

View File

@@ -14,6 +14,13 @@
 # limitations under the License.
 """Pytriton server for token2wav conversion and ASR"""
+from datasets import load_dataset
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from omnisense.models import OmniSenseVoiceSmall
+from pytriton.proxy.types import Request
+from pytriton.triton import Triton, TritonConfig
+from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
+from pytriton.decorators import batch
 import argparse
 import io
 import logging
@@ -37,15 +44,6 @@ zh_tn_model = ZhNormalizer(
     overwrite_cache=True,
 )
-from pytriton.decorators import batch
-from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
-from pytriton.triton import Triton, TritonConfig
-from pytriton.proxy.types import Request
-from omnisense.models import OmniSenseVoiceSmall
-from cosyvoice.cli.cosyvoice import CosyVoice2
-from datasets import load_dataset
 sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
@@ -78,7 +76,6 @@ class _ASR_Server:
         return {"TRANSCRIPTS": transcripts}
 def audio_decode_cosyvoice2(
     audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
 ):
@@ -123,12 +120,12 @@ def get_random_prompt_from_dataset(dataset):
     """
     random_idx = random.randint(0, len(dataset) - 1)
     sample = dataset[random_idx]
     # Extract audio data
     audio_data = sample["audio"]
     audio_array = audio_data["array"]
     sample_rate = audio_data["sampling_rate"]
     # Convert audio to 16kHz if needed
     if sample_rate != 16000:
         num_samples = int(len(audio_array) * (16000 / sample_rate))
@@ -141,6 +138,7 @@ def get_random_prompt_from_dataset(dataset):
     prompt_text = prompt_text.replace(" ", "")
     return prompt_text, prompt_speech_16k
+
 class _Token2Wav_ASR:
     """Wraps a single OmniSenseVoiceSmall model instance for Triton."""
@@ -163,6 +161,7 @@ class _Token2Wav_ASR:
         self.codec_decoder = CosyVoice2(
             "/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
         )
+
     @batch
     def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
         """
@@ -236,7 +235,6 @@ class _Token2Wav_ASR:
         transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
         rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
         return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}