commit dc96e4c984
parent 70991d7327
Author: lyuxiang.lx
Date:   2025-08-21 21:03:58 +08:00

3 changed files with 45 additions and 43 deletions


@@ -207,7 +207,7 @@ class CosyVoice3(CosyVoice):
             raise ValueError('{} not found!'.format(hyper_yaml_path))
         with open(hyper_yaml_path, 'r') as f:
             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
-        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
+        assert get_model_type(configs) == CosyVoice3Model, 'do not use {} for CosyVoice3 initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
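Note: the assert in CosyVoice3.__init__ previously compared against CosyVoice2Model, a copy-paste left-over from the CosyVoice2 wrapper; together with the get_model_type fix in the last file below, a CosyVoice3 checkpoint now has to resolve to CosyVoice3Model before the frontend is built. A minimal usage sketch (the model directory name is illustrative):

    from cosyvoice.cli.cosyvoice import CosyVoice3

    # Raises AssertionError if the config in model_dir does not describe a
    # CosyVoice3 model (DiT flow + causal HiFT, see get_model_type below).
    cosyvoice = CosyVoice3('pretrained_models/CosyVoice3-0.5B')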


@@ -56,8 +56,9 @@ class TransformerLM(torch.nn.Module):
         )
         # 2. build speech token language model related modules
-        self.sos_eos = 0
+        self.sos = 0
         self.task_id = 1
+        self.eos_token = self.speech_token_size
         self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
         self.llm = llm
         self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
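In cosyvoice/llm/llm.py (the second file, per the imports further down), the overloaded sos_eos attribute is split: sos stays an index into the two-entry llm_embedding, while the new eos_token names the output class that ends a speech sequence. A sketch of the resulting layout (codec size illustrative):

    # Indices into llm_embedding = torch.nn.Embedding(2, llm_input_size):
    sos = 0        # start-of-sequence marker prepended to every LM input
    task_id = 1    # separator between the text prefix and speech tokens
    # Output classes of llm_decoder = nn.Linear(..., speech_token_size + 1):
    speech_token_size = 4096           # illustrative codec vocabulary size
    eos_token = speech_token_size      # the one id past the codec range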
@@ -85,10 +86,10 @@ class TransformerLM(torch.nn.Module):
         encoder_out = self.text_encoder_affine_layer(encoder_out)
         return encoder_out, encoder_out_lens

-    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
+    def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
-        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
+        lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                     for i in range(len(text_token))]
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
         lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
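pad_unpad_sequence strips each sample back to its true length, concatenates the per-sample pieces, then re-pads the batch. A toy sketch of that unpad-then-pad pattern (torch.nn.utils.rnn API; shapes illustrative):

    import torch
    from torch.nn.utils.rnn import pad_sequence, unpad_sequence

    # Two sequences of true lengths 3 and 5 with feature dim 2.
    batch = pad_sequence([torch.ones(3, 2), torch.ones(5, 2)], batch_first=True)
    lens = torch.tensor([3, 5])
    seqs = unpad_sequence(batch, lens.cpu(), batch_first=True)  # list of [len_i, 2]
    # The real code concatenates per sample:
    # [sos_emb][speaker embedding][text embeddings][task_id_emb][speech embeddings]
    repacked = pad_sequence(seqs, batch_first=True, padding_value=0)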
@@ -127,14 +128,14 @@ class TransformerLM(torch.nn.Module):
         embedding = embedding.unsqueeze(1)

         # 3. eos and task_id
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)

         # 4. encode speech_token
         speech_token = self.speech_embedding(speech_token)

         # 5. unpad and pad
-        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
+        lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
                                                          task_id_emb, speech_token, speech_token_len)

         # 6. run lm forward
@@ -193,13 +194,13 @@ class TransformerLM(torch.nn.Module):
             embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)

         # 3. concat llm_input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)

         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -215,11 +216,8 @@ class TransformerLM(torch.nn.Module):
                                               att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                              device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            # force continue decode first token
-            if i == 0:
-                logp[:, self.speech_token_size] = -float('inf')
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-            if top_ids == self.speech_token_size:
+            if top_ids == self.eos_token:
                 break
             # in stream mode, yield token one by one
             yield top_ids
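The deleted branch masked the EOS logit on the very first step, but sampling_ids is already called with ignore_eos=True for every i < min_len, so the mask was redundant whenever min_len >= 1; the break condition now uses the named eos_token. A hedged sketch of the contract this relies on (not the repo's exact implementation):

    def sampling_ids_sketch(logp, out_tokens, sampling_fn, eos_token, ignore_eos):
        # Resample until a non-EOS id is drawn when ignore_eos is set.
        while True:
            top_ids = sampling_fn(logp, out_tokens)
            if not (ignore_eos and top_ids == eos_token):
                return top_ids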
@@ -276,9 +274,10 @@ class Qwen2LM(TransformerLM):
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
         # 2. build speech token language model related modules
-        self.sos_eos = 0
+        self.sos = 0
         self.task_id = 1
-        self.fill_token = 2
+        self.eos_token = speech_token_size
+        self.fill_token = speech_token_size + 2
         self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
         self.llm = llm
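Qwen2LM previously stored fill_token = 2 as a bare offset while call sites hard-coded self.speech_token_size + 2; the attributes now hold absolute ids, so every comparison below can use them directly. Sketch of the resulting id map (codec size illustrative):

    speech_token_size = 6561               # illustrative codec vocabulary size
    eos_token = speech_token_size          # ends a speech sequence
    fill_token = speech_token_size + 2     # 'need more text' marker in bistream mode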
@@ -312,7 +311,7 @@ class Qwen2LM(TransformerLM):
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                 this_lm_target, this_lm_input = [], []
                 this_lm_target.append(IGNORE_ID)
-                this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
+                this_lm_input.append(self.llm_embedding.weight[self.sos].reshape(1, -1))
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                     this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                     this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
@@ -320,21 +319,21 @@ class Qwen2LM(TransformerLM):
                         assert len(this_speech_token) == self.mix_ratio[1]
                         this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
                         this_lm_target += this_speech_token
-                        this_lm_target.append(self.speech_token_size + 2)
+                        this_lm_target.append(self.fill_token)
                         this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
                         this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
                     else:
                         this_lm_target += [-1] * len(this_text_token)
                         this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
-                        this_lm_target.append(self.speech_token_size)
+                        this_lm_target.append(self.eos_token)
                         this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
                         this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
                         this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                 this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
             # unistream sequence
             else:
-                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
-                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
+                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
+                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos].reshape(1, -1), text_token_emb[i],
                                               self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
             lm_target.append(this_lm_target)
             lm_input.append(this_lm_input)
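In the unistream branch the target masks every text-side position and supervises only the speech tokens plus a final EOS; because input and target have equal length, the first speech token is predicted from the task_id position. A toy alignment with three text tokens and speech tokens [a, b]:

    # lm_input : [sos]  [t1]   [t2]   [t3]   [task_id]  [a]  [b]
    # lm_target: [IGN]  [IGN]  [IGN]  [IGN]  [ a ]      [b]  [eos_token]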
@@ -445,13 +444,13 @@ class Qwen2LM(TransformerLM):
         text = self.llm.model.model.embed_tokens(text)

         # 3. concat llm_input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)

         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -501,10 +500,8 @@ class Qwen2LM(TransformerLM):
                                    cache=cache)
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-            if top_ids == self.speech_token_size:
+            if top_ids in self.stop_token_ids:
                 break
-            if top_ids > self.speech_token_size:
-                continue
             # in stream mode, yield token one by one
             yield top_ids
             out_tokens.append(top_ids)
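Stopping is now a single membership test against stop_token_ids (added in this commit; shown below for CosyVoice3LM). This also changes behavior: ids beyond the codec range used to be skipped with continue, whereas now any special id ends decoding, matching how a vLLM backend treats its stop set. The equivalent check, sketched:

    stop_token_ids = [speech_token_size + i for i in range(4)]  # the four specials
    if top_ids in stop_token_ids:
        ...  # terminate instead of silently dropping the token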
@@ -526,13 +523,13 @@ class Qwen2LM(TransformerLM):
         device = prompt_text.device

         # 1. prepare input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb], dim=1)
+        lm_input = torch.concat([sos_emb], dim=1)

         # 2. iterate text
         out_tokens = []
@@ -554,12 +551,12 @@ class Qwen2LM(TransformerLM):
                     break
                 # no prompt_speech_token_emb remain, can decode some speech token
                 if prompt_speech_token_emb.size(1) == 0:
-                    if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
+                    if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                         logging.info('get fill token, need to append more text token')
                         if text_cache.size(1) >= self.mix_ratio[0]:
                             lm_input_text = text_cache[:, :self.mix_ratio[0]]
                             logging.info('append {} text token'.format(lm_input_text.size(1)))
-                            if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
+                            if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
                                 lm_input = lm_input_text
                             else:
                                 lm_input = torch.concat([lm_input, lm_input_text], dim=1)
@@ -574,16 +571,16 @@ class Qwen2LM(TransformerLM):
                                        cache=cache)
                 logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                 if next_fill_index != -1 and len(out_tokens) == next_fill_index:
-                    top_ids = self.speech_token_size + 2
+                    top_ids = self.fill_token
                     next_fill_index += (self.mix_ratio[1] + 1)
                 else:
                     top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
-                if top_ids == self.speech_token_size + 2:
+                if top_ids == self.fill_token:
                     next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                     logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                 out_tokens.append(top_ids)
                 if top_ids >= self.speech_token_size:
-                    if top_ids == self.speech_token_size + 2:
+                    if top_ids == self.fill_token:
                         break
                     else:
                         raise ValueError('should not get token {}'.format(top_ids))
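Bistream decoding enforces a fixed text/speech cadence: once a fill token shows up at position p, the next one is injected (not sampled) exactly mix_ratio[1] + 1 positions later, i.e. after the next block of speech tokens. A toy trace (mix_ratio values illustrative):

    mix_ratio = [5, 15]          # 5 text tokens alternate with 15 speech tokens
    # fill token first sampled when len(out_tokens) == 15:
    next_fill_index = 15 + mix_ratio[1] + 1    # == 31
    # at position 31 the loop sets top_ids = fill_token itself and advances:
    next_fill_index += mix_ratio[1] + 1        # == 47, and so on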
@@ -602,7 +599,7 @@ class Qwen2LM(TransformerLM):
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
             out_tokens.append(top_ids)
             if top_ids >= self.speech_token_size:
-                if top_ids == self.speech_token_size:
+                if top_ids == self.eos_token:
                     break
                 else:
                     raise ValueError('should not get token {}'.format(top_ids))
@@ -628,10 +625,10 @@ class CosyVoice3LM(Qwen2LM):
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
         # 2. build speech token language model related modules
-        self.sos = 0
-        self.eos = 1
-        self.task_id = 2
-        self.fill_token = 3
+        self.sos = speech_token_size + 0
+        self.eos_token = speech_token_size + 1
+        self.task_id = speech_token_size + 2
+        self.fill_token = speech_token_size + 3
         self.llm = llm
         self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
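CosyVoice3LM has no separate two-entry llm_embedding for the specials: all four markers live directly after the codec ids in the speech-token space, and note that eos_token sits at speech_token_size + 1 here, unlike Qwen2LM where it equals speech_token_size. Sketch of the map (codec size illustrative):

    speech_token_size = 6561                  # illustrative
    sos        = speech_token_size + 0
    eos_token  = speech_token_size + 1
    task_id    = speech_token_size + 2
    fill_token = speech_token_size + 3
    # llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200):
    # classes [0, speech_token_size) are codec tokens; the +200 tail holds the
    # four specials plus reserved headroom.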
@@ -649,6 +646,11 @@ class CosyVoice3LM(Qwen2LM):
         self.sampling = sampling
         self.mix_ratio = mix_ratio

+        # 5. vllm related
+        self.stop_token_ids = [speech_token_size + i for i in range(4)]
+        self.vllm_output_queue = {}
+
     @torch.inference_mode()
     def inference(
             self,
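stop_token_ids collects those four specials so an external engine can stop on any of them. A hedged sketch of how a vLLM backend might consume the list (SamplingParams and its stop_token_ids field are real vLLM API; the values are illustrative):

    from vllm import SamplingParams

    speech_token_size = 6561  # illustrative
    sampling_params = SamplingParams(
        top_p=0.8,
        max_tokens=1024,
        stop_token_ids=[speech_token_size + i for i in range(4)],
    )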
@@ -670,13 +672,13 @@ class CosyVoice3LM(Qwen2LM):
         text = self.llm.model.model.embed_tokens(text)

         # 3. concat llm_input
-        sos_eos_emb = self.speech_embedding.weight[self.speech_token_size + self.sos].reshape(1, 1, -1)
-        task_id_emb = self.speech_embedding.weight[self.speech_token_size + self.task_id].reshape(1, 1, -1)
+        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)

         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)


@@ -32,10 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
                                              RelPositionMultiHeadedAttention)
 from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
 from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
-from cosyvoice.llm.llm import TransformerLM, Qwen2LM
+from cosyvoice.llm.llm import TransformerLM, Qwen2LM, CosyVoice3LM
 from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
 from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
-from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model

 COSYVOICE_ACTIVATION_CLASSES = {
@@ -80,6 +80,6 @@ def get_model_type(configs):
         return CosyVoiceModel
     if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
         return CosyVoice2Model
-    if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
-        return CosyVoice2Model
+    if isinstance(configs['llm'], CosyVoice3LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
+        return CosyVoice3Model
     raise TypeError('No valid model type found!')
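With the dispatch corrected, a DiT-flow plus causal-HiFT config resolves to CosyVoice3Model instead of masquerading as CosyVoice2Model, which is exactly what makes the constructor assert in the first file meaningful. A usage sketch (the yaml path is illustrative):

    from hyperpyyaml import load_hyperpyyaml

    with open('cosyvoice3.yaml', 'r') as f:   # illustrative config path
        configs = load_hyperpyyaml(f)
    assert get_model_type(configs) == CosyVoice3Model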