mirror of https://github.com/FunAudioLLM/CosyVoice.git, synced 2026-02-04 17:39:25 +08:00
update
@@ -127,7 +127,7 @@ class TransformerLM(torch.nn.Module):
         embedding = self.spk_embed_affine_layer(embedding)
         embedding = embedding.unsqueeze(1)

-        # 3. eos and task_id
+        # 3. sos and task_id
         sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)

@@ -300,7 +300,7 @@ class Qwen2LM(TransformerLM):
         self.stop_token_ids = [speech_token_size + i for i in range(3)]
         self.vllm_output_queue = {}

-    def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
+    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len):
         lm_target, lm_input = [], []
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
@@ -311,7 +311,7 @@ class Qwen2LM(TransformerLM):
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                 this_lm_target, this_lm_input = [], []
                 this_lm_target.append(IGNORE_ID)
-                this_lm_input.append(self.llm_embedding.weight[self.sos].reshape(1, -1))
+                this_lm_input.append(sos_emb.squeeze(dim=0))
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                     this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                     this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
@@ -327,14 +327,13 @@ class Qwen2LM(TransformerLM):
                         this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
                         this_lm_target.append(self.eos_token)
                         this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
-                        this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
+                        this_lm_input.append(task_id_emb.squeeze(dim=0))
                         this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                 this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
             # unistream sequence
             else:
                 this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
-                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos].reshape(1, -1), text_token_emb[i],
-                                              self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
+                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
             lm_target.append(this_lm_target)
             lm_input.append(this_lm_input)
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
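For reference, a minimal standalone sketch of the shape convention this change relies on (dummy names and sizes, not the real module attributes): the caller now builds sos_emb/task_id_emb as (1, 1, D) tensors, and prepare_lm_input_target squeezes them back to (1, D) before concatenating them with each sample's token embeddings.

    import torch

    D = 8                                      # dummy embedding dim
    llm_embedding = torch.nn.Embedding(10, D)  # stand-in for self.llm_embedding
    sos, task_id = 0, 1                        # placeholder special-token ids

    # caller side: (1, 1, D)
    sos_emb = llm_embedding.weight[sos].reshape(1, 1, -1)
    task_id_emb = llm_embedding.weight[task_id].reshape(1, 1, -1)

    # inside prepare_lm_input_target: squeeze to (1, D) and concat per sample
    text_token_emb_i = torch.randn(5, D)       # one sample's text embeddings
    speech_token_emb_i = torch.randn(15, D)    # one sample's speech embeddings
    this_lm_input = torch.concat(
        [sos_emb.squeeze(dim=0), text_token_emb_i, task_id_emb.squeeze(dim=0), speech_token_emb_i], dim=0)
    print(this_lm_input.shape)                 # torch.Size([22, 8]) = 1 + 5 + 1 + 15 rows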
@@ -362,11 +361,15 @@ class Qwen2LM(TransformerLM):
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)

+        # 3. sos and task_id
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+
         # 2. encode speech_token
         speech_token_emb = self.speech_embedding(speech_token)

         # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len)
         lm_target = lm_target.to(device)

         # 4. run lm forward
@@ -391,6 +394,10 @@ class Qwen2LM(TransformerLM):
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)

+        # 3. sos and task_id
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+
         # 2. encode speech_token
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
         reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
@@ -400,8 +407,8 @@ class Qwen2LM(TransformerLM):
         speech_token_combined_emb = self.speech_embedding(speech_token_combined)

         # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
-                                                                         speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
+                                                                         task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
         lm_target = lm_target.to(device)

         # 4. run lm forward
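A note on the hunk above: only the text-side tensors are duplicated with .repeat, so that the accepted and rejected speech sequences are paired with the same text prefix, while sos_emb/task_id_emb are passed once and reused for every sample inside prepare_lm_input_target. A tiny sketch of that duplication with dummy shapes (not the real batch):

    import torch

    B, L, D = 2, 4, 8
    text_token = torch.randint(0, 100, (B, L))
    text_token_emb = torch.randn(B, L, D)
    text_token_len = torch.tensor([4, 3], dtype=torch.int32)

    print(text_token.repeat(2, 1).shape)         # torch.Size([4, 4])    -> batch dim doubled
    print(text_token_emb.repeat(2, 1, 1).shape)  # torch.Size([4, 4, 8])
    print(text_token_len.repeat(2))              # tensor([4, 3, 4, 3], dtype=torch.int32)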
@@ -650,6 +657,43 @@ class CosyVoice3LM(Qwen2LM):
         self.stop_token_ids = [speech_token_size + i for i in range(4)]
         self.vllm_output_queue = {}

+    def forward(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        """
+        Args:
+            text: (B, L, D)
+            text_lengths: (B,)
+            audio: (B, T, N) or (B, T)
+            audio_lengths: (B,)
+        """
+        text_token = batch['text_token'].to(device)
+        text_token_len = batch['text_token_len'].to(device)
+        speech_token = batch['speech_token'].to(device)
+        speech_token_len = batch['speech_token_len'].to(device)
+
+        # 1. encode text_token
+        text_token_emb = self.llm.model.model.embed_tokens(text_token)
+
+        # 3. sos and task_id
+        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
+
+        # 2. encode speech_token
+        speech_token_emb = self.speech_embedding(speech_token)
+
+        # 3. prepare llm_input/target
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len)
+        lm_target = lm_target.to(device)
+
+        # 4. run lm forward
+        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+        logits = self.llm_decoder(lm_output)
+        loss = self.criterion_ce(logits, lm_target.to(device))
+        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
+        return {'loss': loss, 'acc': acc}
+
     @torch.inference_mode()
     def inference(
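To illustrate how the newly added CosyVoice3LM.forward is meant to be driven, here is a hypothetical training-step sketch; model and optimizer are assumed to be already-constructed instances and the token ranges/lengths are dummies, but the batch keys mirror the ones read inside forward().

    import torch

    def train_step(model, optimizer, device):
        # dummy batch; key names match those accessed in CosyVoice3LM.forward
        batch = {
            'text_token': torch.randint(0, 1000, (2, 12)),
            'text_token_len': torch.tensor([12, 9], dtype=torch.int32),
            'speech_token': torch.randint(0, 4096, (2, 60)),
            'speech_token_len': torch.tensor([60, 45], dtype=torch.int32),
        }
        out = model(batch, device)   # returns {'loss': ..., 'acc': ...}
        optimizer.zero_grad()
        out['loss'].backward()
        optimizer.step()
        return out['loss'].item(), out['acc']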