diff --git a/examples/grpo/cosyvoice2/infer_dataset.py b/examples/grpo/cosyvoice2/infer_dataset.py
index f0d22d7..f72cd77 100644
--- a/examples/grpo/cosyvoice2/infer_dataset.py
+++ b/examples/grpo/cosyvoice2/infer_dataset.py
@@ -53,7 +53,7 @@
 except RuntimeError:
     pass

-TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501
+TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501


 def audio_decode_cosyvoice2(
diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
index 6bce5cc..1c6c423 100644
--- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
+++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
@@ -464,7 +464,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):

 def collate_fn(batch):
     ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
-    for i, item in enumerate(batch):
+    for item in batch:
         generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
         audio = torch.from_numpy(item['prompt_audio']['array']).float()
         prompt_audios_list.append(audio)
@@ -496,7 +496,7 @@ if __name__ == "__main__":

     data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)

-    for epoch in range(args.warmup):
+    for _ in range(args.warmup):
         start_time = time.time()
         for batch in data_loader:
             ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
diff --git a/runtime/triton_trtllm/offline_inference.py b/runtime/triton_trtllm/offline_inference.py
index e3eac2f..326fb0e 100644
--- a/runtime/triton_trtllm/offline_inference.py
+++ b/runtime/triton_trtllm/offline_inference.py
@@ -512,7 +512,7 @@ def main(args):
             ))
         else:
             outputs = []
-            for i, chat in enumerate(batch["chat_list"]):
+            for chat in batch["chat_list"]:
                 payload = {
                     "model": args.openai_model_name,
                     "messages": chat,
diff --git a/runtime/triton_trtllm/streaming_inference.py b/runtime/triton_trtllm/streaming_inference.py
index 9c4a2fb..7cfb6f9 100644
--- a/runtime/triton_trtllm/streaming_inference.py
+++ b/runtime/triton_trtllm/streaming_inference.py
@@ -13,7 +13,7 @@ import soundfile as sf
 def collate_fn(batch):
     ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
     prompt_speech_tokens_list, prompt_text_list = [], []
-    for i, item in enumerate(batch):
+    for item in batch:
         generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
         audio = torch.from_numpy(item['prompt_audio']['array']).float()
         prompt_audios_list.append(audio)