This commit is contained in:
root
2025-07-29 08:40:51 +00:00
parent d1c354eac7
commit 62d082634e
7 changed files with 71 additions and 68 deletions

View File

@@ -21,6 +21,7 @@ import torch
from safetensors import safe_open from safetensors import safe_open
from transformers import AutoTokenizer from transformers import AutoTokenizer
def get_args(): def get_args():
parser = ArgumentParser() parser = ArgumentParser()
@@ -39,6 +40,7 @@ def get_args():
args = parser.parse_args() args = parser.parse_args()
return args return args
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
@@ -67,4 +69,3 @@ if __name__ == "__main__":
hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"] hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"]
torch.save(hf_tensors, args.output_path) torch.save(hf_tensors, args.output_path)

View File

@@ -105,6 +105,7 @@ def extract_speech_ids(speech_tokens_str):
print(f"Unexpected token: {token_str}") print(f"Unexpected token: {token_str}")
return speech_ids return speech_ids
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens): def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
"""Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>""" """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
speech_id_str = "" speech_id_str = ""
@@ -182,14 +183,13 @@ def get_args():
return args return args
def data_collator(batch, tokenizer, s3_tokenizer): def data_collator(batch, tokenizer, s3_tokenizer):
"""Simplified data collator for batch_size=1 processing""" """Simplified data collator for batch_size=1 processing"""
target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu") device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
input_ids_list, prompt_audio_list, prompt_text_list = [], [], [] input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
mels, prompt_audio_cosy2tokens_list = [], [] mels, prompt_audio_cosy2tokens_list = [], []
for i, item in enumerate(batch): for item in batch:
prompt_text, target_text = ( prompt_text, target_text = (
item["prompt_text"], item["prompt_text"],
item["target_text"], item["target_text"],
@@ -227,7 +227,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device)) codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
for i in range(len(codes)): for i in range(len(codes)):
prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()]) prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list): for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens) prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
# Create chat template for LLM generation # Create chat template for LLM generation
chat = [ chat = [
@@ -244,7 +244,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
) )
input_ids_list.append(input_ids.squeeze(0)) input_ids_list.append(input_ids.squeeze(0))
# For batch_size=1, no need to pad # For batch_size=1, no need to pad
if len(input_ids_list) == 1: if len(input_ids_list) == 1:
input_ids = input_ids_list[0].unsqueeze(0) input_ids = input_ids_list[0].unsqueeze(0)
@@ -384,7 +383,6 @@ def main():
else: else:
print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping") print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
if rank == 0: if rank == 0:
progress_bar.update(world_size * len(batch["ids"])) progress_bar.update(world_size * len(batch["ids"]))

View File

@@ -23,8 +23,6 @@ import datasets
from verl.utils.hdfs_io import copy, makedirs from verl.utils.hdfs_io import copy, makedirs
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file") parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file")

View File

@@ -31,7 +31,6 @@ import torch
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
def get_args(): def get_args():
parser = ArgumentParser() parser = ArgumentParser()
@@ -96,17 +95,20 @@ if __name__ == "__main__":
# set the weight and bias of the new lm_head to 0 # set the weight and bias of the new lm_head to 0
new_lm_head.weight.data.zero_() new_lm_head.weight.data.zero_()
new_lm_head.bias.data.zero_() new_lm_head.bias.data.zero_()
new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.weight new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight
new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.bias new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias
llm.lm_head = new_lm_head llm.lm_head = new_lm_head
input_embeddings = llm.get_input_embeddings() input_embeddings = llm.get_input_embeddings()
with torch.no_grad(): with torch.no_grad():
input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = speech_embedding.weight input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight
input_embeddings.weight[original_tokenizer_vocab_size+cosyvoice2_token_size+3:original_tokenizer_vocab_size+cosyvoice2_token_size+3+2] = llm_embedding.weight input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight
eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size, original_tokenizer_vocab_size + cosyvoice2_token_size + 1, original_tokenizer_vocab_size + cosyvoice2_token_size + 2] eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size,
original_tokenizer_vocab_size + cosyvoice2_token_size + 1,
original_tokenizer_vocab_size + cosyvoice2_token_size + 2,
original_tokenizer_vocab_size + cosyvoice2_token_size + 3]
llm.generation_config.eos_token_id = eos_token_ids llm.generation_config.eos_token_id = eos_token_ids
llm.generation_config.temperature = 1.0 llm.generation_config.temperature = 1.0
llm.generation_config.top_p = 0.8 llm.generation_config.top_p = 0.8

View File

@@ -18,7 +18,10 @@ Reward calculation for CosyVoice2-0.5B.
from __future__ import annotations from __future__ import annotations
import os, re, warnings, json, time, argparse import re
import json
import time
import argparse
from typing import List from typing import List
import numpy as np import numpy as np
@@ -31,6 +34,7 @@ REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer"
def _parse_ids(token_str: str) -> List[int]: def _parse_ids(token_str: str) -> List[int]:
return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)] return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]
def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float: def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
"""Send token IDs and ground-truth text to the Triton server and get reward.""" """Send token IDs and ground-truth text to the Triton server and get reward."""
@@ -100,7 +104,6 @@ def compute_score(
try: try:
reward = _remote_reward(ids, ground_truth) reward = _remote_reward(ids, ground_truth)
except Exception as e: except Exception as e:
warnings.warn(f"Remote reward server error: {e}; returning 0.0")
reward = 0.0 reward = 0.0
if debug_dump: if debug_dump:
@@ -110,6 +113,7 @@ def compute_score(
return reward return reward
# CLI quick test # CLI quick test
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
@@ -141,7 +145,6 @@ if __name__ == "__main__":
help="Run in non-interactive mode (process all samples without prompts)" help="Run in non-interactive mode (process all samples without prompts)"
) )
parser.add_argument( parser.add_argument(
"--debug", "--debug",
action="store_true", action="store_true",

View File

@@ -102,6 +102,7 @@ import string
punctuation_all = punctuation + string.punctuation punctuation_all = punctuation + string.punctuation
Pathlike = Union[str, Path] Pathlike = Union[str, Path]
def remove_punctuation(text: str) -> str: def remove_punctuation(text: str) -> str:
for x in punctuation_all: for x in punctuation_all:
if x == '\'': if x == '\'':
@@ -109,6 +110,7 @@ def remove_punctuation(text: str) -> str:
text = text.replace(x, '') text = text.replace(x, '')
return text return text
def store_transcripts( def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
) -> None: ) -> None:
@@ -304,6 +306,7 @@ def write_error_stats(
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate) return float(tot_err_rate)
def get_args(): def get_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -590,7 +593,7 @@ def normalize_text_alimeeting(text: str) -> str:
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
""" """
import re import re
text = text.replace('\u00A0', '') # test_hard text = text.replace('\u00A0', '') # test_hard
text = text.replace(" ", "") text = text.replace(" ", "")
text = text.replace("<sil>", "") text = text.replace("<sil>", "")
text = text.replace("<%>", "") text = text.replace("<%>", "")

View File

@@ -14,6 +14,13 @@
# limitations under the License. # limitations under the License.
"""Pytriton server for token2wav conversion and ASR""" """Pytriton server for token2wav conversion and ASR"""
from datasets import load_dataset
from cosyvoice.cli.cosyvoice import CosyVoice2
from omnisense.models import OmniSenseVoiceSmall
from pytriton.proxy.types import Request
from pytriton.triton import Triton, TritonConfig
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
from pytriton.decorators import batch
import argparse import argparse
import io import io
import logging import logging
@@ -37,15 +44,6 @@ zh_tn_model = ZhNormalizer(
overwrite_cache=True, overwrite_cache=True,
) )
from pytriton.decorators import batch
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
from pytriton.triton import Triton, TritonConfig
from pytriton.proxy.types import Request
from omnisense.models import OmniSenseVoiceSmall
from cosyvoice.cli.cosyvoice import CosyVoice2
from datasets import load_dataset
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
@@ -78,7 +76,6 @@ class _ASR_Server:
return {"TRANSCRIPTS": transcripts} return {"TRANSCRIPTS": transcripts}
def audio_decode_cosyvoice2( def audio_decode_cosyvoice2(
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
): ):
@@ -141,6 +138,7 @@ def get_random_prompt_from_dataset(dataset):
prompt_text = prompt_text.replace(" ", "") prompt_text = prompt_text.replace(" ", "")
return prompt_text, prompt_speech_16k return prompt_text, prompt_speech_16k
class _Token2Wav_ASR: class _Token2Wav_ASR:
"""Wraps a single OmniSenseVoiceSmall model instance for Triton.""" """Wraps a single OmniSenseVoiceSmall model instance for Triton."""
@@ -163,6 +161,7 @@ class _Token2Wav_ASR:
self.codec_decoder = CosyVoice2( self.codec_decoder = CosyVoice2(
"/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True "/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
) )
@batch @batch
def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray): def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
""" """
@@ -236,7 +235,6 @@ class _Token2Wav_ASR:
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8") transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1) rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts} return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}