add silent_token

lyuxiang.lx
2025-12-30 09:18:17 +00:00
parent dd5cdb6ebf
commit cfa1c115b2


@@ -60,6 +60,7 @@ class CosyVoiceModel:
         self.mel_overlap_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
+        self.silent_tokens = []
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
@@ -98,6 +99,7 @@ class CosyVoiceModel:
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        cur_silent_token_num, max_silent_token_num = 0, 5
         with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
                 assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
@@ -107,6 +109,12 @@ class CosyVoiceModel:
                                                      prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                      prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                      embedding=llm_embedding.to(self.device)):
+                    if i in self.silent_tokens:
+                        cur_silent_token_num += 1
+                        if cur_silent_token_num > max_silent_token_num:
+                            continue
+                    else:
+                        cur_silent_token_num = 0
                     self.tts_speech_token_dict[uuid].append(i)
             else:
                 for i in self.llm.inference(text=text.to(self.device),
@@ -117,6 +125,12 @@ class CosyVoiceModel:
                                             prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                             embedding=llm_embedding.to(self.device),
                                             uuid=uuid):
+                    if i in self.silent_tokens:
+                        cur_silent_token_num += 1
+                        if cur_silent_token_num > max_silent_token_num:
+                            continue
+                    else:
+                        cur_silent_token_num = 0
                     self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
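
Both branches of llm_job now apply the same run-length cap: silent tokens pass through until more than max_silent_token_num (5) arrive in a row, after which further silent tokens are skipped until a non-silent token resets the counter. A minimal standalone sketch of that counting scheme, using a hypothetical silent-token ID of 0 purely for illustration:

def drop_long_silence(tokens, silent_tokens, max_silent_token_num=5):
    # Same logic as the loop bodies added above: count consecutive silent
    # tokens and stop forwarding them once the run exceeds the cap.
    cur_silent_token_num = 0
    for i in tokens:
        if i in silent_tokens:
            cur_silent_token_num += 1
            if cur_silent_token_num > max_silent_token_num:
                continue
        else:
            cur_silent_token_num = 0
        yield i

# 7 consecutive silent tokens are trimmed down to 5; isolated ones are kept.
stream = [7, 0, 0, 0, 0, 0, 0, 0, 3, 0, 5]
assert list(drop_long_silence(stream, silent_tokens=[0])) == [7, 0, 0, 0, 0, 0, 3, 0, 5]
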
@@ -260,6 +274,7 @@ class CosyVoice2Model(CosyVoiceModel):
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
+        self.silent_tokens = []
 
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
@@ -401,6 +416,8 @@ class CosyVoice3Model(CosyVoice2Model):
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
+        # FSQ silent token
+        self.silent_tokens = [28, 29]
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16):
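
Note that CosyVoiceModel and CosyVoice2Model initialize silent_tokens to an empty list, so the new `i in self.silent_tokens` check never fires and suppression stays disabled for them; only CosyVoice3Model enables it with the FSQ silent-token IDs 28 and 29. A hedged sketch of how the behavior could be toggled on an already-constructed model instance (only the attribute name and the IDs come from this commit; whether other models' tokenizers have comparable silent tokens is an open assumption):

# Hypothetical usage: `model` is an already-constructed CosyVoice3Model instance.
model.silent_tokens = []         # disable silent-token suppression entirely
model.silent_tokens = [28, 29]   # restore the FSQ silent-token IDs from this commit
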