This commit is contained in:
yuekaiz
2025-10-09 15:18:09 +08:00
parent 33aee03ed5
commit a224be6117
4 changed files with 5 additions and 5 deletions

View File

@@ -53,7 +53,7 @@ except RuntimeError:
pass pass
TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501 TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501
def audio_decode_cosyvoice2( def audio_decode_cosyvoice2(

View File

@@ -464,7 +464,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
def collate_fn(batch): def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
for i, item in enumerate(batch): for item in batch:
generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
audio = torch.from_numpy(item['prompt_audio']['array']).float() audio = torch.from_numpy(item['prompt_audio']['array']).float()
prompt_audios_list.append(audio) prompt_audios_list.append(audio)
@@ -496,7 +496,7 @@ if __name__ == "__main__":
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
for epoch in range(args.warmup): for _ in range(args.warmup):
start_time = time.time() start_time = time.time()
for batch in data_loader: for batch in data_loader:
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch

View File

@@ -512,7 +512,7 @@ def main(args):
)) ))
else: else:
outputs = [] outputs = []
for i, chat in enumerate(batch["chat_list"]): for chat in batch["chat_list"]:
payload = { payload = {
"model": args.openai_model_name, "model": args.openai_model_name,
"messages": chat, "messages": chat,

View File

@@ -13,7 +13,7 @@ import soundfile as sf
def collate_fn(batch): def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
prompt_speech_tokens_list, prompt_text_list = [], [] prompt_speech_tokens_list, prompt_text_list = [], []
for i, item in enumerate(batch): for item in batch:
generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
audio = torch.from_numpy(item['prompt_audio']['array']).float() audio = torch.from_numpy(item['prompt_audio']['array']).float()
prompt_audios_list.append(audio) prompt_audios_list.append(audio)