mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 18:09:24 +08:00)
refine code
@@ -103,35 +103,29 @@ class CosyVoiceModel:
         with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
                 assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
-                for i in self.llm.inference_bistream(text=text,
-                                                     prompt_text=prompt_text.to(self.device),
-                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                     embedding=llm_embedding.to(self.device)):
-                    if i in self.silent_tokens:
-                        cur_silent_token_num += 1
-                        if cur_silent_token_num > max_silent_token_num:
-                            continue
-                    else:
-                        cur_silent_token_num = 0
-                    self.tts_speech_token_dict[uuid].append(i)
-            else:
-                for i in self.llm.inference(text=text.to(self.device),
-                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-                                            prompt_text=prompt_text.to(self.device),
-                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                            embedding=llm_embedding.to(self.device),
-                                            uuid=uuid):
-                    if i in self.silent_tokens:
-                        cur_silent_token_num += 1
-                        if cur_silent_token_num > max_silent_token_num:
-                            continue
-                    else:
-                        cur_silent_token_num = 0
-                    self.tts_speech_token_dict[uuid].append(i)
+                token_generator = self.llm.inference_bistream(text=text,
+                                                              prompt_text=prompt_text.to(self.device),
+                                                              prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                              prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                              embedding=llm_embedding.to(self.device))
+            else:
+                token_generator = self.llm.inference(text=text.to(self.device),
+                                                     text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                                     prompt_text=prompt_text.to(self.device),
+                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                     embedding=llm_embedding.to(self.device),
+                                                     uuid=uuid)
+            for i in token_generator:
+                if i in self.silent_tokens:
+                    cur_silent_token_num += 1
+                    if cur_silent_token_num > max_silent_token_num:
+                        continue
+                else:
+                    cur_silent_token_num = 0
+                self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
     def vc_job(self, source_speech_token, uuid):
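
Note on the change: the refactor deduplicates the token-consumption loop. Each branch of the isinstance(text, Generator) check now only chooses which generator to run (inference_bistream for streaming text input, inference otherwise), and a single shared loop drains token_generator and applies the silent-token filtering. A minimal standalone sketch of the same pattern, with hypothetical stand-in generators and token values (only the control flow mirrors the commit):

    from typing import Generator

    SILENT_TOKENS = {0}        # hypothetical: token ids treated as silence
    MAX_SILENT_TOKEN_NUM = 2   # hypothetical: longest silence run to keep

    def inference_bistream(text):
        yield from [5, 0, 0, 0, 7]        # stand-in for streaming decoding

    def inference(text):
        yield from [1, 0, 2, 0, 0, 0, 3]  # stand-in for one-shot decoding

    def llm_job(text):
        # Select the generator once, based on the input type ...
        if isinstance(text, Generator):
            token_generator = inference_bistream(text)
        else:
            token_generator = inference(text)
        # ... then drain it in a single shared loop, as in the commit.
        out, cur_silent_token_num = [], 0
        for i in token_generator:
            if i in SILENT_TOKENS:
                cur_silent_token_num += 1
                if cur_silent_token_num > MAX_SILENT_TOKEN_NUM:
                    continue  # drop silence beyond the allowed run length
            else:
                cur_silent_token_num = 0
            out.append(i)
        return out

    print(llm_job("hello"))  # [1, 0, 2, 0, 0, 3]: third consecutive 0 dropped

Before the change, this filtering block was copy-pasted into both branches, so any fix to it had to be applied twice. Selecting the generator first keeps the two inference signatures visible while the post-processing lives in one place.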
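
The unchanged context line at the top of the hunk is also worth a note: torch.cuda.amp.autocast takes its enabled flag as the first positional argument, so mixed precision is active only when fp16 is requested and the LLM is not backed by vllm (which manages its own precision). A minimal sketch of that gating, with hypothetical flag names:

    import torch

    fp16, uses_vllm = True, False  # hypothetical flags
    # autocast(enabled=False) is a no-op, so the same code path serves
    # fp32, fp16, and vllm-backed inference without branching.
    with torch.cuda.amp.autocast(fp16 and not uses_vllm):
        pass  # run LLM inference here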