Mirror of https://github.com/FunAudioLLM/CosyVoice.git
Commit message: update
@@ -207,7 +207,7 @@ class CosyVoice3(CosyVoice):
             raise ValueError('{} not found!'.format(hyper_yaml_path))
         with open(hyper_yaml_path, 'r') as f:
             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
-        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
+        assert get_model_type(configs) == CosyVoice3Model, 'do not use {} for CosyVoice3 initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
@@ -56,8 +56,9 @@ class TransformerLM(torch.nn.Module):
         )
 
         # 2. build speech token language model related modules
-        self.sos_eos = 0
+        self.sos = 0
         self.task_id = 1
+        self.eos_token = self.speech_token_size
         self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
         self.llm = llm
         self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
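
Note: in TransformerLM the combined start/end embedding index keeps row 0 but is now named `sos`, and the end-of-speech id gets an explicit alias `eos_token = self.speech_token_size`. Since `llm_decoder` is `nn.Linear(llm_output_size, speech_token_size + 1)`, that eos id is exactly the last output class. A minimal sketch of that relationship, with illustrative sizes (not from the repository's config):

    import torch
    import torch.nn as nn

    speech_token_size = 4096      # illustrative size, not taken from any CosyVoice config
    llm_output_size = 1024        # illustrative size

    eos_token = speech_token_size                              # same convention as the diff
    llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)

    logits = llm_decoder(torch.randn(1, llm_output_size))
    assert logits.shape[-1] == eos_token + 1                   # eos_token is the last class index
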
@@ -85,10 +86,10 @@ class TransformerLM(torch.nn.Module):
         encoder_out = self.text_encoder_affine_layer(encoder_out)
         return encoder_out, encoder_out_lens
 
-    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
+    def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
-        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
+        lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                     for i in range(len(text_token))]
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
         lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
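
Note: `pad_unpad_sequence` only renames an argument here. For readers new to the helper, a toy sketch (not from the repository) of the unpad -> per-sample concat -> re-pad pattern it implements; the doubled concat stands in for the real sos/text/task/speech concatenation:

    import torch
    from torch.nn.utils.rnn import pad_sequence, unpad_sequence

    IGNORE_ID = -1                                                      # padding value, as in the diff

    padded = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)     # two padded sequences, feature dim 2
    lens = torch.tensor([3, 2])                                         # true lengths

    seqs = unpad_sequence(padded, lens, batch_first=True)               # list of [len_i, 2] tensors
    merged = [torch.concat([s, s], dim=0) for s in seqs]                # stand-in for the sos/text/task/speech concat
    lm_input = pad_sequence(merged, batch_first=True, padding_value=IGNORE_ID)
    lm_input_len = torch.tensor([m.size(0) for m in merged], dtype=torch.int32)
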
@@ -127,14 +128,14 @@ class TransformerLM(torch.nn.Module):
         embedding = embedding.unsqueeze(1)
 
         # 3. eos and task_id
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
 
         # 4. encode speech_token
         speech_token = self.speech_embedding(speech_token)
 
         # 5. unpad and pad
-        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
+        lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
                                                          task_id_emb, speech_token, speech_token_len)
 
         # 6. run lm forward
@@ -193,13 +194,13 @@ class TransformerLM(torch.nn.Module):
             embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
 
         # 3. concat llm_input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -215,11 +216,8 @@ class TransformerLM(torch.nn.Module):
                                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                  device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            # force continue decode first token
-            if i == 0:
-                logp[:, self.speech_token_size] = -float('inf')
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-            if top_ids == self.speech_token_size:
+            if top_ids == self.eos_token:
                 break
             # in stream mode, yield token one by one
             yield top_ids
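
Note: the deleted "force continue decode first token" branch manually pushed the eos logit to -inf at step 0. With `eos_token` now named explicitly, that job is left to the `ignore_eos=True if i < min_len else False` flag passed to `sampling_ids`, which covers step 0 whenever `min_len >= 1`. A minimal sketch of that kind of eos masking; `mask_eos` is a hypothetical helper, not the repository's `sampling_ids`:

    import torch

    def mask_eos(logp: torch.Tensor, eos_token: int, ignore_eos: bool) -> torch.Tensor:
        # hypothetical helper: suppress the eos class so sampling cannot select it yet
        if ignore_eos:
            logp = logp.clone()
            logp[..., eos_token] = -float('inf')
        return logp

    eos_token = 4096                                           # illustrative id
    logp = torch.log_softmax(torch.randn(1, eos_token + 1), dim=-1)
    assert torch.isinf(mask_eos(logp, eos_token, ignore_eos=True)[0, eos_token])
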
@@ -276,9 +274,10 @@ class Qwen2LM(TransformerLM):
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
         # 2. build speech token language model related modules
-        self.sos_eos = 0
+        self.sos = 0
         self.task_id = 1
-        self.fill_token = 2
+        self.eos_token = speech_token_size
+        self.fill_token = speech_token_size + 2
 
         self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
         self.llm = llm
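
Note: in Qwen2LM the two learned control embeddings keep their rows (`sos` at index 0, `task_id` at index 1 of the two-row `llm_embedding`). What the rename makes explicit is that `eos_token` and `fill_token` are output-side ids appended after the ordinary speech tokens, rather than small literals that get offset at each use site. A sketch of the id layout, with an assumed codebook size:

    speech_token_size = 6561                             # assumed size; treat as illustrative

    qwen2lm_ids = {
        'speech tokens': range(0, speech_token_size),    # ordinary codec tokens
        'eos_token': speech_token_size,                  # end of speech
        'fill_token': speech_token_size + 2,             # "need more text" marker used in bistream mode
    }
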
@@ -312,7 +311,7 @@ class Qwen2LM(TransformerLM):
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                 this_lm_target, this_lm_input = [], []
                 this_lm_target.append(IGNORE_ID)
-                this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
+                this_lm_input.append(self.llm_embedding.weight[self.sos].reshape(1, -1))
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                     this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                     this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
@@ -320,21 +319,21 @@ class Qwen2LM(TransformerLM):
                         assert len(this_speech_token) == self.mix_ratio[1]
                         this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
                         this_lm_target += this_speech_token
-                        this_lm_target.append(self.speech_token_size + 2)
+                        this_lm_target.append(self.fill_token)
                         this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
                         this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
                     else:
                         this_lm_target += [-1] * len(this_text_token)
                         this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
-                        this_lm_target.append(self.speech_token_size)
+                        this_lm_target.append(self.eos_token)
                         this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
                         this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
                         this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                 this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
             # unistream sequence
             else:
-                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
-                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
+                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
+                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos].reshape(1, -1), text_token_emb[i],
                                               self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
             lm_target.append(this_lm_target)
             lm_input.append(this_lm_input)
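
Note: the training-target changes above are pure renames (`speech_token_size + 2` becomes `fill_token`, `speech_token_size` becomes `eos_token`, `sos_eos` becomes `sos`), so the targets themselves do not change. For readers new to the bistream layout, a toy sketch (not from the repository) of what one target block looks like; the ids and `mix_ratio` are illustrative:

    IGNORE_ID = -1
    mix_ratio = [5, 15]                       # illustrative text:speech block ratio
    speech_block = list(range(100, 115))      # 15 illustrative speech token ids
    eos_token, fill_token = 6561, 6563        # illustrative: speech_token_size and speech_token_size + 2

    # a full block: text positions are ignored in the loss, then the speech tokens
    # are predicted, then a fill_token asks the model to wait for more text
    full_block_target = [IGNORE_ID] * (mix_ratio[0] - 1) + speech_block + [fill_token]

    # the final partial block predicts the remaining speech tokens and ends with eos_token
    tail_text_len, tail_speech = 2, [115, 116, 117]
    tail_block_target = [IGNORE_ID] * tail_text_len + tail_speech + [eos_token]
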
@@ -445,13 +444,13 @@ class Qwen2LM(TransformerLM):
         text = self.llm.model.model.embed_tokens(text)
 
         # 3. concat llm_input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -501,10 +500,8 @@ class Qwen2LM(TransformerLM):
                                                       cache=cache)
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-            if top_ids == self.speech_token_size:
+            if top_ids in self.stop_token_ids:
                 break
-            if top_ids > self.speech_token_size:
-                continue
             # in stream mode, yield token one by one
             yield top_ids
             out_tokens.append(top_ids)
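
Note: this is a behaviour change rather than a rename: the loop previously stopped only on the single eos id and silently skipped any other id above `speech_token_size`, whereas it now stops on anything in `self.stop_token_ids`. A minimal sketch of the new check; the stop set below mirrors the CosyVoice3LM hunk further down and is an assumption, since the set defined on Qwen2LM is not shown in this diff:

    speech_token_size = 6561                                        # illustrative
    stop_token_ids = [speech_token_size + i for i in range(4)]      # assumed, per the CosyVoice3LM hunk

    def should_stop(top_ids: int) -> bool:
        # any special id past the ordinary speech tokens terminates decoding
        return top_ids in stop_token_ids

    assert should_stop(speech_token_size + 1) and not should_stop(42)
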
@@ -526,13 +523,13 @@ class Qwen2LM(TransformerLM):
 
         device = prompt_text.device
         # 1. prepare input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb], dim=1)
+        lm_input = torch.concat([sos_emb], dim=1)
 
         # 2. iterate text
         out_tokens = []
@@ -554,12 +551,12 @@ class Qwen2LM(TransformerLM):
                     break
             # no prompt_speech_token_emb remain, can decode some speech token
             if prompt_speech_token_emb.size(1) == 0:
-                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
+                if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                     logging.info('get fill token, need to append more text token')
                     if text_cache.size(1) >= self.mix_ratio[0]:
                         lm_input_text = text_cache[:, :self.mix_ratio[0]]
                         logging.info('append {} text token'.format(lm_input_text.size(1)))
-                        if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
+                        if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
                             lm_input = lm_input_text
                         else:
                             lm_input = torch.concat([lm_input, lm_input_text], dim=1)
@@ -574,16 +571,16 @@ class Qwen2LM(TransformerLM):
                                                           cache=cache)
                 logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                 if next_fill_index != -1 and len(out_tokens) == next_fill_index:
-                    top_ids = self.speech_token_size + 2
+                    top_ids = self.fill_token
                     next_fill_index += (self.mix_ratio[1] + 1)
                 else:
                     top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
-                if top_ids == self.speech_token_size + 2:
+                if top_ids == self.fill_token:
                     next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                     logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                 out_tokens.append(top_ids)
                 if top_ids >= self.speech_token_size:
-                    if top_ids == self.speech_token_size + 2:
+                    if top_ids == self.fill_token:
                         break
                     else:
                         raise ValueError('should not get token {}'.format(top_ids))
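
Note: with the rename, the bistream schedule reads directly: whenever a `fill_token` appears (sampled or forced), the next one is expected `mix_ratio[1] + 1` generated tokens later, i.e. after one full speech block plus the fill slot itself. A toy illustration of that bookkeeping, not the repository's code:

    mix_ratio = [5, 15]          # illustrative text:speech ratio
    fill_token = 6563            # illustrative id (speech_token_size + 2 in Qwen2LM)

    out_tokens, next_fill_index = [], -1
    for sampled in [101, 102, fill_token, 103, 104]:      # fake decoded ids
        if next_fill_index != -1 and len(out_tokens) == next_fill_index:
            top_ids = fill_token                          # force the scheduled fill token
            next_fill_index += mix_ratio[1] + 1
        else:
            top_ids = sampled
            if top_ids == fill_token:
                next_fill_index = len(out_tokens) + mix_ratio[1] + 1
        out_tokens.append(top_ids)
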
@@ -602,7 +599,7 @@ class Qwen2LM(TransformerLM):
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
             out_tokens.append(top_ids)
             if top_ids >= self.speech_token_size:
-                if top_ids == self.speech_token_size:
+                if top_ids == self.eos_token:
                     break
                 else:
                     raise ValueError('should not get token {}'.format(top_ids))
@@ -628,10 +625,10 @@ class CosyVoice3LM(Qwen2LM):
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
         # 2. build speech token language model related modules
-        self.sos = 0
-        self.eos = 1
-        self.task_id = 2
-        self.fill_token = 3
+        self.sos = speech_token_size + 0
+        self.eos_token = speech_token_size + 1
+        self.task_id = speech_token_size + 2
+        self.fill_token = speech_token_size + 3
 
         self.llm = llm
         self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
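
Note: this is the core of the commit. In CosyVoice3LM the control tokens stop being small literals (0..3) that were offset by `speech_token_size` at each use site and become absolute ids `speech_token_size + 0..3`; the `@@ -670` hunk below drops the offset at the lookup, so the embedding rows actually used are unchanged. Because `llm_decoder` has `speech_token_size + 200` outputs, all four ids are valid output classes. A sketch of the layout, sizes illustrative:

    speech_token_size = 6561                # illustrative

    sos        = speech_token_size + 0      # start-of-sequence row in speech_embedding
    eos_token  = speech_token_size + 1      # end of speech
    task_id    = speech_token_size + 2      # separator between text and speech
    fill_token = speech_token_size + 3      # "need more text" marker (bistream)

    # all four fit inside the decoder's speech_token_size + 200 output classes
    assert max(sos, eos_token, task_id, fill_token) < speech_token_size + 200
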
@@ -649,6 +646,11 @@ class CosyVoice3LM(Qwen2LM):
         self.sampling = sampling
         self.mix_ratio = mix_ratio
 
+        # 5. vllm related
+        self.stop_token_ids = [speech_token_size + i for i in range(4)]
+        self.vllm_output_queue = {}
+
+
     @torch.inference_mode()
     def inference(
             self,
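
Note: the added `stop_token_ids` simply enumerates those four special ids, which lets a generic decoding loop (or the vLLM-backed path that `vllm_output_queue` appears to serve) stop on any of them, matching the `top_ids in self.stop_token_ids` check in the Qwen2LM inference hunk above. Evaluated with an illustrative size:

    speech_token_size = 6561                                     # illustrative
    stop_token_ids = [speech_token_size + i for i in range(4)]   # [6561, 6562, 6563, 6564]
    # i.e. {sos, eos_token, task_id, fill_token} under the new CosyVoice3LM numbering
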
@@ -670,13 +672,13 @@ class CosyVoice3LM(Qwen2LM):
         text = self.llm.model.model.embed_tokens(text)
 
         # 3. concat llm_input
-        sos_eos_emb = self.speech_embedding.weight[self.speech_token_size + self.sos].reshape(1, 1, -1)
-        task_id_emb = self.speech_embedding.weight[self.speech_token_size + self.task_id].reshape(1, 1, -1)
+        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -32,10 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
                                              RelPositionMultiHeadedAttention)
 from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
 from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
-from cosyvoice.llm.llm import TransformerLM, Qwen2LM
+from cosyvoice.llm.llm import TransformerLM, Qwen2LM, CosyVoice3LM
 from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
 from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
-from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
 
 
 COSYVOICE_ACTIVATION_CLASSES = {
@@ -80,6 +80,6 @@ def get_model_type(configs):
         return CosyVoiceModel
     if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
         return CosyVoice2Model
-    if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
-        return CosyVoice2Model
+    if isinstance(configs['llm'], CosyVoice3LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
+        return CosyVoice3Model
     raise TypeError('No valid model type found!')
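
Note: `get_model_type` now returns `CosyVoice3Model` for the DiT-flow plus causal-HiFT combination only when the LLM is a `CosyVoice3LM`, instead of mapping that combination to `CosyVoice2Model` for any `Qwen2LM`; this is what the tightened assert in the `CosyVoice3.__init__` hunk at the top relies on. A hedged usage sketch; the module path for `get_model_type`, the yaml file name, and the model directory are assumptions, not taken from the diff:

    import os
    from hyperpyyaml import load_hyperpyyaml

    from cosyvoice.utils.class_utils import get_model_type       # assumed location of get_model_type
    from cosyvoice.cli.model import CosyVoice3Model

    model_dir = '/path/to/CosyVoice3'                             # hypothetical local checkpoint dir
    hyper_yaml_path = os.path.join(model_dir, 'cosyvoice3.yaml')  # hypothetical config name

    with open(hyper_yaml_path, 'r') as f:
        configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})

    # mirrors the assert in the CosyVoice3.__init__ hunk at the top of this diff
    assert get_model_type(configs) == CosyVoice3Model, 'do not use {} for CosyVoice3 initialization!'.format(model_dir)
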