This commit is contained in:
lyuxiang.lx
2026-01-29 10:29:22 +00:00
parent f26cde56df
commit 84e41729ea
4 changed files with 20 additions and 13 deletions

View File

@@ -367,8 +367,11 @@ class Qwen2LM(TransformerLM):
"""
text_token = batch['text_token'].to(device)
text_token_len = batch['text_token_len'].to(device)
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
if 'speech_token' not in batch:
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
else:
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
# 1. encode text_token
text_token_emb = self.llm.model.model.embed_tokens(text_token)
@@ -686,8 +689,12 @@ class CosyVoice3LM(Qwen2LM):
"""
text_token = batch['text_token'].to(device)
text_token_len = batch['text_token_len'].to(device)
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
if 'speech_token' not in batch:
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
else:
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
# NOTE: instruct_token should be appended to the sequence; this is not implemented yet
instruct_token = batch['instruct_token'].to(device)
instruct_token_len = batch['instruct_token_len'].to(device)