mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 09:59:23 +08:00
update
This commit is contained in:
@@ -164,6 +164,9 @@ class TransformerLM(torch.nn.Module):
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
if self.fp16 is True:
|
||||
embedding = embedding.half()
|
||||
|
||||
device = text.device
|
||||
text = torch.concat([prompt_text, text], dim=1)
|
||||
text_len += prompt_text_len
|
||||
@@ -178,7 +181,7 @@ class TransformerLM(torch.nn.Module):
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
embedding = embedding.unsqueeze(dim=1)
|
||||
else:
|
||||
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
|
||||
|
||||
# 3. concat llm_input
|
||||
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
||||
@@ -310,7 +313,7 @@ class Qwen2LM(torch.nn.Module):
|
||||
text = self.llm.model.model.embed_tokens(text)
|
||||
|
||||
# 2. encode embedding
|
||||
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
|
||||
|
||||
# 3. concat llm_input
|
||||
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
||||
|
||||
Reference in New Issue
Block a user