mirror of https://github.com/FunAudioLLM/CosyVoice.git, synced 2026-02-04 09:29:25 +08:00
fix lint
@@ -21,6 +21,7 @@ import torch
 from safetensors import safe_open
 from transformers import AutoTokenizer
 
+
 def get_args():
     parser = ArgumentParser()
 
@@ -39,6 +40,7 @@ def get_args():
     args = parser.parse_args()
     return args
 
+
 if __name__ == "__main__":
     args = get_args()
 
@@ -67,4 +69,3 @@ if __name__ == "__main__":
     hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"]
 
     torch.save(hf_tensors, args.output_path)
-
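
Note: the `lm_head.weight = embed_tokens.weight` assignment above materializes tied weights, the converted checkpoint reuses the input embedding matrix as the output projection. A minimal sketch of the same idea in plain PyTorch, with toy sizes:

    import torch
    from torch import nn

    emb = nn.Embedding(100, 32)               # token embeddings
    lm_head = nn.Linear(32, 100, bias=False)  # output projection
    lm_head.weight = emb.weight               # tie: both names point at one tensor
    # The state-dict copy in the hunk achieves the same effect for a
    # checkpoint whose lm_head was never stored separately.
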
@@ -105,6 +105,7 @@ def extract_speech_ids(speech_tokens_str):
             print(f"Unexpected token: {token_str}")
     return speech_ids
 
+
 def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
     """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
     speech_id_str = ""
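
Note: the hunk only shows fragments of these two helpers. A minimal sketch of the pair, assuming the `<|s_N|>` format named in the docstring (not the exact upstream bodies):

    import re
    from typing import List

    def extract_speech_ids_sketch(token_strs: List[str]) -> List[int]:
        # Keep only tokens of the form <|s_23456|>; report anything else.
        ids = []
        for token_str in token_strs:
            m = re.fullmatch(r"<\|s_(\d+)\|>", token_str)
            if m:
                ids.append(int(m.group(1)))
            else:
                print(f"Unexpected token: {token_str}")
        return ids

    def convert_cosy2_tokens_to_speech_id_str_sketch(cosy2_tokens) -> str:
        # Inverse mapping: integer codec IDs back to "<|s_N|>" text.
        return "".join(f"<|s_{int(t)}|>" for t in cosy2_tokens)
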
@@ -182,14 +183,13 @@ def get_args():
     return args
 
 
-
 def data_collator(batch, tokenizer, s3_tokenizer):
     """Simplified data collator for batch_size=1 processing"""
     target_sample_rate = 16000  # CosyVoice2 uses 16kHz for prompt audio
     device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
     input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
     mels, prompt_audio_cosy2tokens_list = [], []
-    for i, item in enumerate(batch):
+    for item in batch:
         prompt_text, target_text = (
             item["prompt_text"],
             item["target_text"],
@@ -227,7 +227,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
     codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
     for i in range(len(codes)):
         prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
-    for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
+    for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
         prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
         # Create chat template for LLM generation
         chat = [
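
Note: the hunk cuts off at `chat = [`. A plausible shape for that message list, inferred from the chat template added later in this commit (the exact contents are an assumption):

    chat = [
        {"role": "user", "content": prompt_text + target_text},
        {"role": "assistant", "content": prompt_audio_cosy2_id_str},
    ]
    # With the template below, apply_chat_template would render this as
    # "<|sos|>" + text + "<|task_id|>" + prompt speech tokens, i.e. the
    # prefix the LLM must continue with new speech tokens.
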
@@ -244,7 +244,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         )
         input_ids_list.append(input_ids.squeeze(0))
 
-
     # For batch_size=1, no need to pad
     if len(input_ids_list) == 1:
         input_ids = input_ids_list[0].unsqueeze(0)
@@ -256,7 +255,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
             for input_ids in input_ids_list
         ]
         input_ids = torch.stack(input_ids_list)
-
+
     ids = [item["id"] for item in batch]
 
     return {
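
Note: for batch sizes above one, the collator pads each sequence before `torch.stack`. A self-contained sketch of the pad-then-stack step, assuming left padding with a `pad_token_id` (the padding side is an assumption; the hunk does not show it):

    import torch
    import torch.nn.functional as F

    def pad_and_stack(seqs, pad_token_id=0):
        # Left-pad each 1-D LongTensor to the longest sequence so that
        # decoder-only generation continues from aligned end positions.
        max_len = max(s.size(0) for s in seqs)
        padded = [F.pad(s, (max_len - s.size(0), 0), value=pad_token_id) for s in seqs]
        return torch.stack(padded)

    batch = pad_and_stack([torch.tensor([5, 6]), torch.tensor([7])])
    # tensor([[5, 6],
    #         [0, 7]])
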
@@ -287,7 +286,7 @@ def main():
     assert torch.cuda.is_available()
     world_size, local_rank, rank = init_distributed()
     device = torch.device(f"cuda:{local_rank}")
-
+
     # Load LLM model and tokenizer directly
     tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
     model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
@@ -329,7 +328,7 @@ def main():
     for batch in dataloader:
         with torch.no_grad():
             input_ids = batch["input_ids"].to(device)
-
+
             # Generate speech tokens using LLM
             outputs = model.generate(
                 input_ids,
@@ -339,31 +338,31 @@ def main():
                 temperature=args.temperature,
                 top_k=args.top_k,
             )
 
             # Process each sample in the batch
             for i in range(len(batch["ids"])):
                 # Extract generated tokens (excluding input)
                 input_length = input_ids[i].shape[0]
                 generated_ids = outputs[i][input_length:-1]  # Remove last token if needed
                 speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
 
                 # Extract speech IDs from token strings like <|s_23456|>
                 speech_ids = extract_speech_ids(speech_tokens_str)
 
                 if len(speech_ids) == 0:
                     print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                     continue
 
                 # Convert to tensor for CosyVoice2
                 audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
 
                 if args.prompt_text is not None:
                     current_prompt_text = args.prompt_text
                     current_prompt_audio = prompt_speech_16k
                 else:
                     current_prompt_text = batch["prompt_text"][i]
                     current_prompt_audio = batch["prompt_audio_list"][i]
 
                 if current_prompt_audio is not None:
                     # Generate audio using CosyVoice2
                     audio_hat = audio_decode_cosyvoice2(
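
Note: `model.generate` on a decoder-only LM returns the prompt plus the continuation in one tensor, which is why the loop slices from `input_length` (and drops the final token, presumably an end-of-speech marker). A toy illustration:

    import torch

    input_ids_i = torch.tensor([1, 2, 3])                     # prompt
    outputs_i = torch.cat([input_ids_i, torch.tensor([7, 8, 9, 0])])
    generated = outputs_i[input_ids_i.shape[0]:-1]            # tensor([7, 8, 9])
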
@@ -372,18 +371,17 @@
                         current_prompt_audio,
                         cosyvoice_codec,
                     )
 
                     # Convert to numpy and save
                     generated_wave = audio_hat.squeeze(0).cpu().numpy()
                     target_sample_rate = 24000
 
                     utt = batch["ids"][i]
                     sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
-
                     print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
                 else:
                     print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
 
         if rank == 0:
             progress_bar.update(world_size * len(batch["ids"]))
 
@@ -23,8 +23,6 @@ import datasets
 
 from verl.utils.hdfs_io import copy, makedirs
 
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file")
@@ -31,7 +31,6 @@ import torch
 sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
 
 
-
 def get_args():
     parser = ArgumentParser()
 
@@ -96,17 +95,20 @@ if __name__ == "__main__":
     # set the weight and bias of the new lm_head to 0
     new_lm_head.weight.data.zero_()
     new_lm_head.bias.data.zero_()
-    new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.weight
-    new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.bias
+    new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight
+    new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias
 
     llm.lm_head = new_lm_head
     input_embeddings = llm.get_input_embeddings()
 
     with torch.no_grad():
-        input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = speech_embedding.weight
-        input_embeddings.weight[original_tokenizer_vocab_size+cosyvoice2_token_size+3:original_tokenizer_vocab_size+cosyvoice2_token_size+3+2] = llm_embedding.weight
+        input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight
+        input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight
 
-    eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size, original_tokenizer_vocab_size + cosyvoice2_token_size + 1, original_tokenizer_vocab_size + cosyvoice2_token_size + 2]
+    eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 1,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 2,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 3]
     llm.generation_config.eos_token_id = eos_token_ids
     llm.generation_config.temperature = 1.0
     llm.generation_config.top_p = 0.8
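
Note: this block splices CosyVoice2's speech-token head and embeddings into the rows appended after the original text vocabulary. A runnable toy version of the same slice arithmetic (sizes invented for illustration; the real script copies the base model's head and embeddings into the leading rows in the same way):

    import torch
    from torch import nn

    text_vocab, speech_vocab, extra, hidden = 1000, 64, 3, 16

    old_lm_head = nn.Linear(hidden, text_vocab, bias=True)
    llm_decoder = nn.Linear(hidden, speech_vocab + extra, bias=True)

    new_lm_head = nn.Linear(hidden, text_vocab + speech_vocab + extra, bias=True)
    with torch.no_grad():
        new_lm_head.weight.zero_()
        new_lm_head.bias.zero_()
        new_lm_head.weight[:text_vocab] = old_lm_head.weight
        new_lm_head.bias[:text_vocab] = old_lm_head.bias
        # Same slice the lint fix reformats above: speech logits occupy the tail rows.
        new_lm_head.weight[text_vocab:text_vocab + speech_vocab + extra] = llm_decoder.weight
        new_lm_head.bias[text_vocab:text_vocab + speech_vocab + extra] = llm_decoder.bias
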
@@ -121,4 +123,4 @@ if __name__ == "__main__":
 
     TEMPLATE = "{%- for message in messages %}{%- if message['role'] == 'user' %}{{- '<|sos|>' + message['content'] + '<|task_id|>' }}{%- elif message['role'] == 'assistant' %}{{- message['content']}}{%- endif %}{%- endfor %}"
     tokenizer.chat_template = TEMPLATE
-    tokenizer.save_pretrained(args.save_path)
+    tokenizer.save_pretrained(args.save_path)
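
Note: this Jinja template flattens a user/assistant pair into a single prompt string. A quick check of what `tokenizer.apply_chat_template(messages, tokenize=False)` would produce with it (the example messages are invented):

    messages = [
        {"role": "user", "content": "今天天气不错"},
        {"role": "assistant", "content": "<|s_12|><|s_345|>"},
    ]
    # Rendered result:
    # "<|sos|>今天天气不错<|task_id|><|s_12|><|s_345|>"
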
@@ -18,7 +18,10 @@ Reward calculation for CosyVoice2-0.5B.
 
 from __future__ import annotations
 
-import os, re, warnings, json, time, argparse
+import re
+import json
+import time
+import argparse
 from typing import List
 
 import numpy as np
@@ -31,6 +34,7 @@ REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer"
 def _parse_ids(token_str: str) -> List[int]:
     return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]
 
+
 def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
     """Send token IDs and ground-truth text to the Triton server and get reward."""
 
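
Note: `_parse_ids` pulls every well-formed speech token out of a solution string and ignores everything else:

    >>> _parse_ids("<|s_12|><|s_345|>noise<|s_6|>")
    [12, 345, 6]
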
@@ -100,7 +104,6 @@ def compute_score(
     try:
         reward = _remote_reward(ids, ground_truth)
     except Exception as e:
         warnings.warn(f"Remote reward server error: {e}; returning 0.0")
         reward = 0.0
-
     if debug_dump:
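
Note: the body of `_remote_reward` sits outside this hunk. A hedged sketch of what a KServe-v2 style `infer` call against REWARD_SERVER_URL could look like, matching the TOKENS / TOKEN_LENS / GT_TEXT tensors the Triton server below declares (datatypes and shapes here are assumptions, not the confirmed upstream payload):

    import requests
    from typing import List

    def _remote_reward_sketch(tokens: List[int], ground_truth: str,
                              timeout: float = 200.0) -> float:
        payload = {
            "inputs": [
                {"name": "TOKENS", "shape": [1, len(tokens)],
                 "datatype": "INT32", "data": tokens},
                {"name": "TOKEN_LENS", "shape": [1, 1],
                 "datatype": "INT32", "data": [len(tokens)]},
                {"name": "GT_TEXT", "shape": [1, 1],
                 "datatype": "BYTES", "data": [ground_truth]},
            ]
        }
        resp = requests.post(REWARD_SERVER_URL, json=payload, timeout=timeout)
        resp.raise_for_status()
        outputs = {o["name"]: o for o in resp.json()["outputs"]}
        return float(outputs["REWARDS"]["data"][0])
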
@@ -110,46 +113,46 @@
 
     return reward
 
 
 # CLI quick test
 if __name__ == "__main__":
     import sys
 
     def get_args():
         """Parse command line arguments."""
         parser = argparse.ArgumentParser(
             description="Test TTS CER scoring with data from JSONL file",
             formatter_class=argparse.ArgumentDefaultsHelpFormatter
         )
 
         parser.add_argument(
             "--input", "-i",
             type=str,
             default="data/emilia_zh-cosy-tiny-test.jsonl",
             help="Path to input JSONL file"
         )
 
         parser.add_argument(
             "--max-samples", "-n",
             type=int,
             default=None,
             help="Maximum number of samples to process (default: all)"
         )
 
         parser.add_argument(
             "--no-interactive",
             action="store_true",
             help="Run in non-interactive mode (process all samples without prompts)"
         )
 
         parser.add_argument(
             "--debug",
             action="store_true",
             help="Enable debug mode"
         )
 
         return parser.parse_args()
 
     def load_jsonl(file_path: str):
         """Load data from jsonl file."""
         data = []
@@ -157,37 +160,37 @@ if __name__ == "__main__":
             for line in f:
                 data.append(json.loads(line.strip()))
         return data
 
     def code_to_solution_str(code_list: List[int]) -> str:
         """Convert code list to solution string format."""
         return ''.join([f"<|s_{code}|>" for code in code_list])
 
     # Parse command line arguments
     args = get_args()
 
     try:
         # Load data from jsonl file
         print(f"Loading data from: {args.input}")
         data_list = load_jsonl(args.input)
         print(f"Loaded {len(data_list)} samples")
 
         # Limit samples if specified
         if args.max_samples is not None:
             data_list = data_list[:args.max_samples]
             print(f"Processing first {len(data_list)} samples (limited by --max-samples)")
 
         # Process each sample
         begin_time = time.time()
         for i, sample in enumerate(data_list):
             print(f"\n--- Sample {i+1}/{len(data_list)} ---")
             print(f"Index: {sample.get('index', 'unknown')}")
             print(f"Text: {sample['text']}")
 
             # Extract required fields
             code_list = sample['code']
             ground_truth = sample['text']
             data_source = sample.get('index', f'sample_{i}')  # Use index as data_source
 
             # Convert code list to solution string
             solution_str = code_to_solution_str(code_list)
             print(f"Solution tokens: {len(code_list)} tokens")
@@ -195,7 +198,7 @@ if __name__ == "__main__":
                 print(f"Solution string: {solution_str}")
             else:
                 print(f"Solution string preview: {solution_str[:100]}..." if len(solution_str) > 100 else f"Solution string: {solution_str}")
-
+
             # Call compute_score function
             try:
                 score = compute_score(
@@ -208,7 +211,7 @@ if __name__ == "__main__":
                 print(f"Final Score: {score:.4f}")
             except Exception as e:
                 print(f"Error computing score: {e}")
-
+
             # Ask user if they want to continue (for interactive mode)
             if not args.no_interactive and i < len(data_list) - 1:
                 try:
@@ -218,7 +221,7 @@ if __name__ == "__main__":
             except KeyboardInterrupt:
                 print("\nStopped by user")
                 break
-
+
         print(f"\nProcessed {min(i+1, len(data_list))} samples")
         end_time = time.time()
         print(f"Time taken: {end_time - begin_time} seconds")
 
@@ -102,6 +102,7 @@ import string
 punctuation_all = punctuation + string.punctuation
 Pathlike = Union[str, Path]
 
+
 def remove_punctuation(text: str) -> str:
     for x in punctuation_all:
         if x == '\'':
@@ -109,6 +110,7 @@ def remove_punctuation(text: str) -> str:
         text = text.replace(x, '')
     return text
 
+
 def store_transcripts(
     filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
 ) -> None:
@@ -304,6 +306,7 @@ def write_error_stats(
             print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
     return float(tot_err_rate)
 
+
 def get_args():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -533,7 +536,7 @@ def get_args():
         default=None,
         help="wav_base_name label",
     )
-
+
     # Dataset related arguments for loading labels when label file is not provided
     parser.add_argument(
         "--dataset-name",
@@ -541,14 +544,14 @@
         default="yuekai/seed_tts_cosy2",
         help="Huggingface dataset name for loading labels",
     )
 
     parser.add_argument(
         "--split-name",
         type=str,
         default="wenetspeech4tts",
         help="Dataset split name for loading labels",
     )
 
     return parser.parse_args()
 
@@ -590,7 +593,7 @@ def normalize_text_alimeeting(text: str) -> str:
     See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
     """
     import re
-    text = text.replace('\u00A0', '') # test_hard
+    text = text.replace('\u00A0', '')  # test_hard
     text = text.replace(" ", "")
     text = text.replace("<sil>", "")
     text = text.replace("<%>", "")
@@ -685,10 +688,10 @@ def main():
     print(
         f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
     )
 
     # Load labels either from file or from dataset
     labels_dict = {}
 
     if args.label:
         # Load labels from file (original functionality)
         print(f"Loading labels from file: {args.label}")
@@ -716,11 +719,11 @@
             split=args.split_name,
             trust_remote_code=True,
         )
 
         for item in dataset:
             audio_id = item["id"]
             labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
 
         print(f"Loaded {len(labels_dict)} labels from dataset")
 
     # Perform evaluation if labels are available
@@ -750,4 +753,4 @@
 
 
 if __name__ == "__main__":
-    main()
+    main()
@@ -14,6 +14,13 @@
 # limitations under the License.
 """Pytriton server for token2wav conversion and ASR"""
 
+from datasets import load_dataset
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from omnisense.models import OmniSenseVoiceSmall
+from pytriton.proxy.types import Request
+from pytriton.triton import Triton, TritonConfig
+from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
+from pytriton.decorators import batch
 import argparse
 import io
 import logging
@@ -37,15 +44,6 @@ zh_tn_model = ZhNormalizer(
     overwrite_cache=True,
 )
 
-from pytriton.decorators import batch
-from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
-from pytriton.triton import Triton, TritonConfig
-from pytriton.proxy.types import Request
-
-from omnisense.models import OmniSenseVoiceSmall
-from cosyvoice.cli.cosyvoice import CosyVoice2
-
-from datasets import load_dataset
 
 sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
 
@@ -78,7 +76,6 @@ class _ASR_Server:
         return {"TRANSCRIPTS": transcripts}
 
 
-
 def audio_decode_cosyvoice2(
     audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
 ):
@@ -123,12 +120,12 @@ def get_random_prompt_from_dataset(dataset):
     """
     random_idx = random.randint(0, len(dataset) - 1)
     sample = dataset[random_idx]
 
     # Extract audio data
     audio_data = sample["audio"]
     audio_array = audio_data["array"]
     sample_rate = audio_data["sampling_rate"]
 
     # Convert audio to 16kHz if needed
     if sample_rate != 16000:
         num_samples = int(len(audio_array) * (16000 / sample_rate))
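
Note: the `num_samples` computation implies length-proportional resampling to 16 kHz. A sketch of the conversion using scipy (the actual resampler used upstream is not shown in this hunk):

    import numpy as np
    from scipy.signal import resample

    def to_16k(audio_array: np.ndarray, sample_rate: int) -> np.ndarray:
        # Resample to 16 kHz by scaling the sample count, matching the
        # num_samples arithmetic in the diff above.
        if sample_rate == 16000:
            return audio_array
        num_samples = int(len(audio_array) * (16000 / sample_rate))
        return resample(audio_array, num_samples)
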
@@ -141,6 +138,7 @@ def get_random_prompt_from_dataset(dataset):
     prompt_text = prompt_text.replace(" ", "")
     return prompt_text, prompt_speech_16k
 
+
 class _Token2Wav_ASR:
     """Wraps a single OmniSenseVoiceSmall model instance for Triton."""
 
@@ -163,6 +161,7 @@ class _Token2Wav_ASR:
         self.codec_decoder = CosyVoice2(
             "/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
         )
 
+
     @batch
     def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
         """
@@ -236,7 +235,6 @@ class _Token2Wav_ASR:
         transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
         rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
 
-
         return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}
 
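
Note: the reward computation itself sits outside these hunks; given that the server re-synthesizes audio from the tokens and transcribes it with SenseVoice before returning per-sample REWARDS, a plausible scoring rule is an ASR-accuracy reward. A sketch under that assumption (`jiwer` is one CER implementation; the upstream metric choice is not shown here):

    from jiwer import cer  # pip install jiwer; any CER routine would do

    def reward_from_transcript(transcript: str, gt_text: str) -> float:
        # Bounded reward: 1.0 for a perfect transcript, 0.0 at/beyond 100% CER.
        return max(0.0, 1.0 - cer(gt_text, transcript))
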