mirror of https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
send streaming as args
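Pass the streaming flag per call instead of fixing it at construction time: the self.flow.encoder.streaming / self.flow.decoder.estimator.streaming defaults are removed from __init__, and token2wav gains a stream argument that is forwarded to self.flow.inference(streaming=...).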
@@ -258,9 +258,6 @@ class CosyVoice2Model(CosyVoiceModel):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
-        # NOTE default setting for jit/onnx export, you can set to False when using pytorch inference
-        self.flow.encoder.streaming = True
-        self.flow.decoder.estimator.streaming = True
         self.hift = hift
         self.fp16 = fp16
         self.trt_concurrent = trt_concurrent
@@ -290,7 +287,7 @@ class CosyVoice2Model(CosyVoiceModel):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
 
-    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, finalize=False, speed=1.0):
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
             tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                              token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
@@ -299,6 +296,7 @@ class CosyVoice2Model(CosyVoiceModel):
                                              prompt_feat=prompt_feat.to(self.device),
                                              prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                              embedding=embedding.to(self.device),
+                                             streaming=stream,
                                              finalize=finalize)
             tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
             # append hift cache
@@ -356,6 +354,7 @@ class CosyVoice2Model(CosyVoiceModel):
                                                      embedding=flow_embedding,
                                                      token_offset=token_offset,
                                                      uuid=this_uuid,
+                                                     stream=stream,
                                                      finalize=False)
                     token_offset += this_token_hop_len
                     yield {'tts_speech': this_tts_speech.cpu()}
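For context, a minimal sketch of the call pattern this change enables: the streaming mode travels with each request instead of living on the flow module, so one model instance can serve streaming and offline calls side by side. The FakeFlow stand-in, function shapes, and tensor sizes below are illustrative assumptions, not the repo's actual API.

    import torch


    class FakeFlow:
        # Stand-in for the flow module; `streaming` is a per-call argument
        # rather than an attribute set once in __init__.
        def inference(self, token, streaming=False, finalize=False):
            mode = 'chunked' if streaming else 'full'
            print(f'flow.inference: mode={mode}, finalize={finalize}')
            # Fake mel output: (batch, n_mels, frames); shapes are illustrative.
            return torch.zeros(1, 80, token.shape[1] * 2), None


    def token2wav(flow, token, stream=False, finalize=False):
        # Mirrors the commit: the caller's `stream` flag is forwarded to
        # flow.inference(streaming=...) on every call.
        tts_mel, _ = flow.inference(token=token, streaming=stream, finalize=finalize)
        return tts_mel


    flow = FakeFlow()
    token = torch.zeros(1, 25, dtype=torch.int32)
    token2wav(flow, token, stream=True)                  # streaming request
    token2wav(flow, token, stream=False, finalize=True)  # offline request

Before this commit, self.flow.encoder.streaming = True was set once in __init__ (noted in the removed comment as a default for jit/onnx export), so switching modes meant mutating shared module state; threading stream through token2wav keeps the model stateless with respect to the request type.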