mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
update model inference
This commit is contained in:
@@ -174,7 +174,7 @@ class TransformerLM(torch.nn.Module):
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
embedding = embedding.unsqueeze(dim=1)
|
||||
else:
|
||||
embedding = torch.zeros(1, 0, self.llm_input_size).to(device)
|
||||
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||
|
||||
# 3. concat llm_input
|
||||
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
||||
@@ -182,7 +182,7 @@ class TransformerLM(torch.nn.Module):
|
||||
if prompt_speech_token_len != 0:
|
||||
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||
else:
|
||||
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device)
|
||||
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||
|
||||
# 4. cal min/max_length
|
||||
|
||||
Reference in New Issue
Block a user