fix white space

root
2025-10-09 14:49:22 +08:00
parent 807bb6ee0b
commit 8811e9f33a
7 changed files with 27 additions and 33 deletions

View File

@@ -36,7 +36,7 @@ Stage `0` converts raw JSONL files into the parquet format expected by veRL:
```bash
bash run.sh 0 0
```
Create two JSONL files—`train.jsonl` and `test.jsonl`.
The script will then generate two Parquet files:
```
@@ -111,7 +111,7 @@ bash run.sh 5 5
The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.

> [!TIP]
> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.

## Results
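
For readers following the stage `0` hunk above, a minimal sketch of a JSONL-to-Parquet conversion is shown below. It is not the recipe's own preparation script; it only assumes one JSON object per line and a pandas installation with pyarrow available.

```python
# Minimal sketch, not the recipe's stage-0 code: convert train.jsonl / test.jsonl
# into Parquet files. Column names come from whatever the JSONL records contain.
import pandas as pd

for split in ("train", "test"):
    df = pd.read_json(f"{split}.jsonl", lines=True)   # one JSON object per line
    df.to_parquet(f"{split}.parquet", index=False)    # needs pyarrow (or fastparquet)
```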

View File

@@ -33,7 +33,7 @@ fi
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
  log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
  modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
  python3 pretrained_to_huggingface.py \
    --pretrained-cosyvoice2-path $model_scope_model_path \
    --save-path $sft_model_path
@@ -61,7 +61,7 @@ fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "stage 1: start token2wav asr server for reward function"
  python3 token2wav_asr_server.py --number-of-devices 8
fi

exp_name=official_llm_aishell3_grpo
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -125,7 +125,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
    --backend fsdp \
    --local_dir $llm_path/actor \
    --target_dir $llm_path/merged_hf_model || exit 1
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "stage 4: Test the model"

View File

@@ -254,7 +254,7 @@ class TritonPythonModel:
        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
        finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
        inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]

        # Create and execute inference request
        inference_request = pb_utils.InferenceRequest(
            model_name='token2wav_dit',
@@ -362,7 +362,7 @@ class TritonPythonModel:
                    chunk_index += 1
                else:
                    break

        this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
        sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
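
The two hunks above assemble BLS inputs and stream chunks through the `token2wav_dit` model. A hedged sketch of how such a request is typically executed inside an async Triton Python-backend `execute()` follows; the output name `waveform` and the non-decoupled call are assumptions, not taken from this diff.

```python
# Hedged sketch of executing the BLS request built above (async Triton Python backend).
# "waveform" as the requested output name is an assumption.
import triton_python_backend_utils as pb_utils
from torch.utils.dlpack import from_dlpack

async def run_token2wav(inputs_tensor, request_id):
    inference_request = pb_utils.InferenceRequest(
        model_name='token2wav_dit',
        requested_output_names=['waveform'],
        inputs=inputs_tensor,
        request_id=request_id,
    )
    inference_response = await inference_request.async_exec()
    if inference_response.has_error():
        raise pb_utils.TritonModelException(inference_response.error().message())
    waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
    return from_dlpack(waveform.to_dlpack())  # back to a torch tensor
```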

View File

@@ -62,7 +62,7 @@ def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
    # Create a SHA-256 hash of the byte string
    hasher = hashlib.sha256()
    hasher.update(tensor_bytes)
    return hasher.hexdigest()


class TritonPythonModel:
@@ -111,9 +111,9 @@ class TritonPythonModel:
        target_speech_tokens = target_speech_tokens.squeeze().tolist()
        finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
        request_id = request.request_id()

        wav_array = pb_utils.get_input_tensor_by_name(
            request, "reference_wav").as_numpy()

View File

@@ -133,7 +133,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
        option.intra_op_num_threads = 1
        self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
                                                      providers=["CPUExecutionProvider"])
        self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()

        gpu="l20"
@@ -253,7 +252,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
            speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
            prompt_speech_tokens_list.append(speech_tokens_i)
        return prompt_speech_tokens_list

    def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
        spk_emb_for_flow = []
        for audio in prompt_audios_list:
@@ -263,11 +262,11 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
            spk_emb = self.forward_spk_embedding(spk_feat)
            spk_emb_for_flow.append(spk_emb)
        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
        if self.dtype != torch.float32:
            spk_emb_for_flow = spk_emb_for_flow.to(self.dtype)
        return spk_emb_for_flow

    def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
        prompt_mels_for_flow = []
        prompt_mels_lens_for_flow = []
@@ -283,7 +282,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
        prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
        return prompt_mels_for_flow, prompt_mels_lens_for_flow

    def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
        batch_size = prompt_mels_for_flow.shape[0]
        flow_inputs = []
@@ -318,28 +317,24 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
        self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
    ):
        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)

        prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
        generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
        generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
        return generated_wavs

    def prepare_prompt_audio(
        self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
    ):
        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)

        prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
        prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
        spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
        return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
@@ -365,7 +360,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
    @torch.inference_mode()
    def forward_streaming(
        self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
    ):
        if speaker_id not in self.speaker_cache:
            assert prompt_audio is not None, "prompt_audio is required for new speaker"
            assert prompt_audio_sample_rate == 16000
@@ -384,7 +379,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
        if request_id not in self.streaming_flow_cache:
            self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
            self.hift_cache_dict[request_id] = dict(
                mel = torch.zeros(1, 80, 0, device='cuda'),
                source = torch.zeros(1, 1, 0, device='cuda'),
                speech = torch.zeros(1, 0, device='cuda'),
            )
@@ -445,7 +440,7 @@ def collate_fn(batch):
    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
    for i, item in enumerate(batch):
        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])

        audio = torch.from_numpy(item['prompt_audio']['array']).float()
        prompt_audios_list.append(audio)
        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
        ids.append(item['id'])
@@ -473,20 +468,20 @@ if __name__ == "__main__":
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)

    for epoch in range(args.warmup):
        start_time = time.time()
        for batch in data_loader:
            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
            generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
            for id, wav in zip(ids, generated_wavs):
                torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
        end_time = time.time()
        epoch_time = end_time - start_time
        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")

View File

@@ -365,7 +365,6 @@ def main(args):
        runner = None
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")

    if 'Step-Audio-2-mini' in args.token2wav_path:
        from token2wav_dit import CosyVoice2_Token2Wav
    else:

View File

@@ -14,7 +14,7 @@ def collate_fn(batch):
    prompt_speech_tokens_list, prompt_text_list = [], []
    for i, item in enumerate(batch):
        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])

        audio = torch.from_numpy(item['prompt_audio']['array']).float()
        prompt_audios_list.append(audio)
        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
        ids.append(item['id'])
@@ -37,7 +37,7 @@ def get_args():
if __name__ == "__main__":
    args = get_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
@@ -46,7 +46,7 @@ if __name__ == "__main__":
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)

    token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)

    CHUNK_SIZE = 25
    token_frame_rate = 25
    OVERLAP_SIZE = 0
@@ -68,7 +68,7 @@ if __name__ == "__main__":
        semantic_token_ids_arr, token_offset = [], 0
        flow_prompt_speech_token_len = len(prompt_speech_tokens)

        buffer = generated_speech_tokens
        output_wavs = []
        chunk_index = 0
@@ -97,7 +97,7 @@ if __name__ == "__main__":
            output_wavs[i] = wav.cpu().numpy().squeeze()

        audios = output_wavs
        reconstructed_audio = np.concatenate(audios)
        sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
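
The last hunks sketch the chunked streaming test: tokens are fed to `forward_streaming` in fixed-size chunks and the per-chunk waveforms are concatenated and written out. A hedged, self-contained version of that loop is below; it assumes `forward_streaming` returns one waveform tensor per chunk and omits the script's cache and overlap handling.

```python
# Hedged sketch of the chunked streaming loop; forward_streaming's return value
# (one waveform tensor per chunk) is assumed from its signature, not this diff.
import numpy as np
import soundfile as sf

def synthesize_streaming(token2wav_model, tokens, request_id, speaker_id,
                         prompt_audio, output_path, chunk_size=25):
    output_wavs = []
    token_offset = 0
    while token_offset < len(tokens):
        chunk = tokens[token_offset:token_offset + chunk_size]
        token_offset += len(chunk)
        last_chunk = token_offset >= len(tokens)
        wav = token2wav_model.forward_streaming(
            chunk, last_chunk, request_id, speaker_id,
            prompt_audio=prompt_audio, prompt_audio_sample_rate=16000)
        output_wavs.append(wav.cpu().numpy().squeeze())
    sf.write(output_path, np.concatenate(output_wavs), 24000, "PCM_16")
```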