diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 6bc3b31..694104b 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -60,6 +60,7 @@ class CosyVoiceModel:
         self.mel_overlap_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
+        self.silent_tokens = []
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
@@ -98,6 +99,7 @@ class CosyVoiceModel:
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        cur_silent_token_num, max_silent_token_num = 0, 5
         with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
                 assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
@@ -107,6 +109,12 @@ class CosyVoiceModel:
                                                      prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                      prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                      embedding=llm_embedding.to(self.device)):
+                    if i in self.silent_tokens:
+                        cur_silent_token_num += 1
+                        if cur_silent_token_num > max_silent_token_num:
+                            continue
+                    else:
+                        cur_silent_token_num = 0
                     self.tts_speech_token_dict[uuid].append(i)
             else:
                 for i in self.llm.inference(text=text.to(self.device),
@@ -117,6 +125,12 @@ class CosyVoiceModel:
                                             prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                             embedding=llm_embedding.to(self.device),
                                             uuid=uuid):
+                    if i in self.silent_tokens:
+                        cur_silent_token_num += 1
+                        if cur_silent_token_num > max_silent_token_num:
+                            continue
+                    else:
+                        cur_silent_token_num = 0
                     self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
@@ -260,6 +274,7 @@ class CosyVoice2Model(CosyVoiceModel):
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
+        self.silent_tokens = []
 
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
@@ -401,6 +416,8 @@ class CosyVoice3Model(CosyVoice2Model):
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
+        # FSQ silent token
+        self.silent_tokens = [28, 29]
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16):
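
For context, below is a minimal standalone sketch of the suppression logic the patch adds to llm_job: once more than max_silent_token_num (5) silent tokens occur in a row, the surplus ones are dropped from the speech-token stream, and any non-silent token resets the counter. The helper name and the sample token stream are illustrative, not part of the patch; only CosyVoice3Model ships concrete silent-token ids (FSQ tokens 28 and 29), while the base classes default to an empty list, which makes the check a no-op.

# Standalone sketch (not part of the patch): same counting logic as in llm_job.
def suppress_silent_tokens(token_stream, silent_tokens, max_silent_token_num=5):
    """Yield tokens, dropping silent tokens beyond max_silent_token_num
    consecutive occurrences; a non-silent token resets the run."""
    cur_silent_token_num = 0
    for i in token_stream:
        if i in silent_tokens:
            cur_silent_token_num += 1
            if cur_silent_token_num > max_silent_token_num:
                continue  # surplus silent token: skip it, keep counting the run
        else:
            cur_silent_token_num = 0  # run broken, reset the counter
        yield i

# A run of eight silent tokens (28) is trimmed to five:
print(list(suppress_silent_tokens([7, 28, 28, 28, 28, 28, 28, 28, 28, 42], [28, 29])))
# [7, 28, 28, 28, 28, 28, 42]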