mirror of https://github.com/FunAudioLLM/CosyVoice.git
fix lint
@@ -21,6 +21,7 @@ import torch
 from safetensors import safe_open
 from transformers import AutoTokenizer
 
+
 def get_args():
     parser = ArgumentParser()
 
@@ -39,6 +40,7 @@ def get_args():
     args = parser.parse_args()
     return args
 
+
 if __name__ == "__main__":
     args = get_args()
 
@@ -67,4 +69,3 @@ if __name__ == "__main__":
     hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"]
 
     torch.save(hf_tensors, args.output_path)
-
@@ -105,6 +105,7 @@ def extract_speech_ids(speech_tokens_str):
             print(f"Unexpected token: {token_str}")
     return speech_ids
 
+
 def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
     """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
     speech_id_str = ""
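The hunk only shows the docstring of convert_cosy2_tokens_to_speech_id_str, not its body. A minimal sketch of the conversion the docstring describes, assuming cosy2_tokens is an iterable of integer codec token IDs (the body below is an assumption, not taken from the diff):

    def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
        """Render integer codec tokens as '<|s_23456|>'-style strings."""
        return "".join(f"<|s_{int(token)}|>" for token in cosy2_tokens)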
@@ -182,14 +183,13 @@ def get_args():
     return args
 
 
-
 def data_collator(batch, tokenizer, s3_tokenizer):
     """Simplified data collator for batch_size=1 processing"""
     target_sample_rate = 16000  # CosyVoice2 uses 16kHz for prompt audio
     device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
     input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
     mels, prompt_audio_cosy2tokens_list = [], []
-    for i, item in enumerate(batch):
+    for item in batch:
         prompt_text, target_text = (
             item["prompt_text"],
             item["target_text"],
@@ -227,7 +227,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
         for i in range(len(codes)):
             prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
-    for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
+    for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
         prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
         # Create chat template for LLM generation
         chat = [
@@ -244,7 +244,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         )
         input_ids_list.append(input_ids.squeeze(0))
 
-
     # For batch_size=1, no need to pad
     if len(input_ids_list) == 1:
         input_ids = input_ids_list[0].unsqueeze(0)
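The comment above notes that padding is skipped for batch_size=1. For larger batches a decoder-only model would normally need left padding so generation starts flush with the last prompt token; a hedged sketch of that step (left_pad and pad_token_id are illustrative names, not shown in the diff):

    import torch

    def left_pad(input_ids_list, pad_token_id):
        # Left-pad every sequence to the longest length in the batch.
        max_len = max(ids.size(0) for ids in input_ids_list)
        padded = torch.full((len(input_ids_list), max_len), pad_token_id, dtype=torch.long)
        for row, ids in enumerate(input_ids_list):
            padded[row, max_len - ids.size(0):] = ids
        return padded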
@@ -384,7 +383,6 @@ def main():
         else:
             print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
 
-
         if rank == 0:
             progress_bar.update(world_size * len(batch["ids"]))
 
@@ -23,8 +23,6 @@ import datasets
 
 from verl.utils.hdfs_io import copy, makedirs
 
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file")
@@ -31,7 +31,6 @@ import torch
 sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
 
 
-
 def get_args():
     parser = ArgumentParser()
 
@@ -96,17 +95,20 @@ if __name__ == "__main__":
     # set the weight and bias of the new lm_head to 0
     new_lm_head.weight.data.zero_()
     new_lm_head.bias.data.zero_()
-    new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.weight
-    new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.bias
+    new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight
+    new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias
 
     llm.lm_head = new_lm_head
     input_embeddings = llm.get_input_embeddings()
 
     with torch.no_grad():
-        input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = speech_embedding.weight
-        input_embeddings.weight[original_tokenizer_vocab_size+cosyvoice2_token_size+3:original_tokenizer_vocab_size+cosyvoice2_token_size+3+2] = llm_embedding.weight
+        input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight
+        input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight
 
-    eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size, original_tokenizer_vocab_size + cosyvoice2_token_size + 1, original_tokenizer_vocab_size + cosyvoice2_token_size + 2]
+    eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 1,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 2,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 3]
     llm.generation_config.eos_token_id = eos_token_ids
     llm.generation_config.temperature = 1.0
     llm.generation_config.top_p = 0.8
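The slice assignments above graft CosyVoice2's speech-token head (llm_decoder) and speech embeddings onto an enlarged text-LLM vocabulary: text rows keep their original weights, the next cosyvoice2_token_size + 3 rows take the speech weights, and two further embedding rows take llm_embedding. A condensed, runnable sketch of the same zero-init-then-slice-copy pattern with placeholder sizes (none of the numbers below come from the diff):

    import torch

    hidden_size = 896            # placeholder
    text_vocab = 151_936         # placeholder: original tokenizer size
    speech_vocab = 6_561 + 3     # placeholder: codec tokens + 3 specials

    # New head covers text tokens, speech tokens, and two extra special tokens.
    new_head = torch.nn.Linear(hidden_size, text_vocab + speech_vocab + 2, bias=True)
    speech_head = torch.nn.Linear(hidden_size, speech_vocab, bias=True)  # stands in for llm_decoder

    with torch.no_grad():
        new_head.weight.zero_()
        new_head.bias.zero_()
        # Text rows would be copied from the original lm_head here;
        # speech rows then land in one contiguous slice, as in the diff.
        new_head.weight[text_vocab:text_vocab + speech_vocab] = speech_head.weight
        new_head.bias[text_vocab:text_vocab + speech_vocab] = speech_head.bias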
@@ -18,7 +18,10 @@ Reward calculation for CosyVoice2-0.5B.
 
 from __future__ import annotations
 
-import os, re, warnings, json, time, argparse
+import re
+import json
+import time
+import argparse
 from typing import List
 
 import numpy as np
@@ -31,6 +34,7 @@ REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer"
 def _parse_ids(token_str: str) -> List[int]:
     return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]
 
+
 def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
     """Send token IDs and ground-truth text to the Triton server and get reward."""
 
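Since the _parse_ids regex is shown in full, its behavior is easy to pin down; a quick self-contained check:

    import re
    from typing import List

    def _parse_ids(token_str: str) -> List[int]:
        # Pull the integer N out of every '<|s_N|>' occurrence, in order.
        return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]

    assert _parse_ids("<|s_12|><|s_345|><|s_6|>") == [12, 345, 6]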
@@ -100,7 +104,6 @@ def compute_score(
     try:
         reward = _remote_reward(ids, ground_truth)
     except Exception as e:
-        warnings.warn(f"Remote reward server error: {e}; returning 0.0")
         reward = 0.0
 
     if debug_dump:
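The body of _remote_reward is not shown in this diff. Given the REWARD_SERVER_URL defined above (a KServe v2 infer endpoint, which is what PyTriton serves) and the TOKENS / TOKEN_LENS / GT_TEXT tensor names visible in the server's __call__ later in this commit, a plausible sketch looks like the following; the exact payload layout is an assumption:

    import requests
    from typing import List

    REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer"

    def _remote_reward_sketch(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
        # KServe v2 REST payload; tensor names mirror the server's inputs/outputs.
        payload = {
            "inputs": [
                {"name": "TOKENS", "shape": [1, len(tokens)], "datatype": "INT32", "data": tokens},
                {"name": "TOKEN_LENS", "shape": [1, 1], "datatype": "INT32", "data": [len(tokens)]},
                {"name": "GT_TEXT", "shape": [1, 1], "datatype": "BYTES", "data": [ground_truth]},
            ]
        }
        resp = requests.post(REWARD_SERVER_URL, json=payload, timeout=timeout)
        resp.raise_for_status()
        outputs = {o["name"]: o for o in resp.json()["outputs"]}
        return float(outputs["REWARDS"]["data"][0])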
@@ -110,6 +113,7 @@ def compute_score(
 
     return reward
 
+
 # CLI quick test
 if __name__ == "__main__":
     import sys
@@ -141,7 +145,6 @@ if __name__ == "__main__":
         help="Run in non-interactive mode (process all samples without prompts)"
     )
 
-
     parser.add_argument(
         "--debug",
         action="store_true",
@@ -102,6 +102,7 @@ import string
 punctuation_all = punctuation + string.punctuation
 Pathlike = Union[str, Path]
 
+
 def remove_punctuation(text: str) -> str:
     for x in punctuation_all:
         if x == '\'':
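The hunk cuts off inside the apostrophe branch. In similar WER-normalization utilities the apostrophe is skipped rather than deleted so word-internal contractions survive; a self-contained sketch under that assumption (the zhon import is likewise a guess at where `punctuation` comes from):

    import string
    from zhon.hanzi import punctuation  # assumed source of the CJK punctuation set

    punctuation_all = punctuation + string.punctuation

    def remove_punctuation(text: str) -> str:
        for x in punctuation_all:
            if x == '\'':
                continue  # assumption: keep apostrophes for English contractions
            text = text.replace(x, '')
        return text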
@@ -109,6 +110,7 @@ def remove_punctuation(text: str) -> str:
             text = text.replace(x, '')
     return text
 
+
 def store_transcripts(
     filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
 ) -> None:
@@ -304,6 +306,7 @@ def write_error_stats(
         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
     return float(tot_err_rate)
 
+
 def get_args():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -590,7 +593,7 @@ def normalize_text_alimeeting(text: str) -> str:
     See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
     """
     import re
-    text = text.replace('\u00A0', '') # test_hard
+    text = text.replace('\u00A0', '')  # test_hard
     text = text.replace(" ", "")
     text = text.replace("<sil>", "")
     text = text.replace("<%>", "")
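Taken together, the replacements strip non-breaking spaces, ordinary spaces, and the <sil>/<%> filler markers before scoring. A quick illustration, ignoring any transforms outside this hunk (the input string is invented):

    >>> normalize_text_alimeeting("你好\u00A0 <sil>世界<%>")
    '你好世界'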
@@ -14,6 +14,13 @@
 # limitations under the License.
 """Pytriton server for token2wav conversion and ASR"""
 
+from datasets import load_dataset
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from omnisense.models import OmniSenseVoiceSmall
+from pytriton.proxy.types import Request
+from pytriton.triton import Triton, TritonConfig
+from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
+from pytriton.decorators import batch
 import argparse
 import io
 import logging
@@ -37,15 +44,6 @@ zh_tn_model = ZhNormalizer(
     overwrite_cache=True,
 )
 
-from pytriton.decorators import batch
-from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
-from pytriton.triton import Triton, TritonConfig
-from pytriton.proxy.types import Request
-
-from omnisense.models import OmniSenseVoiceSmall
-from cosyvoice.cli.cosyvoice import CosyVoice2
-
-from datasets import load_dataset
 
 sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
 
@@ -78,7 +76,6 @@ class _ASR_Server:
         return {"TRANSCRIPTS": transcripts}
 
 
-
 def audio_decode_cosyvoice2(
     audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
 ):
@@ -141,6 +138,7 @@ def get_random_prompt_from_dataset(dataset):
     prompt_text = prompt_text.replace(" ", "")
     return prompt_text, prompt_speech_16k
 
+
 class _Token2Wav_ASR:
     """Wraps a single OmniSenseVoiceSmall model instance for Triton."""
 
@@ -163,6 +161,7 @@ class _Token2Wav_ASR:
         self.codec_decoder = CosyVoice2(
             "/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
         )
+
     @batch
     def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
         """
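Elsewhere in this file the @batch-decorated __call__ is presumably bound to a PyTriton endpoint; the exact bind call is outside the hunks shown, so the following is a sketch of the standard PyTriton pattern with this server's tensor names (max_batch_size, shapes, and the dummy infer function are assumptions):

    import numpy as np
    from pytriton.decorators import batch
    from pytriton.model_config import ModelConfig, Tensor
    from pytriton.triton import Triton

    @batch
    def infer_fn(TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
        # Dummy scorer: one float reward and an echoed transcript per batch item.
        rewards = np.zeros((TOKENS.shape[0], 1), dtype=np.float32)
        return {"REWARDS": rewards, "TRANSCRIPTS": GT_TEXT}

    with Triton() as triton:
        triton.bind(
            model_name="token2wav_asr",
            infer_func=infer_fn,
            inputs=[
                Tensor(name="TOKENS", dtype=np.int32, shape=(-1,)),
                Tensor(name="TOKEN_LENS", dtype=np.int32, shape=(1,)),
                Tensor(name="GT_TEXT", dtype=np.bytes_, shape=(1,)),
            ],
            outputs=[
                Tensor(name="REWARDS", dtype=np.float32, shape=(1,)),
                Tensor(name="TRANSCRIPTS", dtype=np.bytes_, shape=(1,)),
            ],
            config=ModelConfig(max_batch_size=32),
        )
        triton.serve()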
@@ -236,7 +235,6 @@ class _Token2Wav_ASR:
         transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
         rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
 
-
         return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}
 
 