Pass `streaming` as an argument instead of setting it as a module attribute

This commit is contained in:
lyuxiang.lx
2025-05-27 13:47:12 +08:00
parent 54d21b40f0
commit cbfed4a9ee
5 changed files with 14 additions and 19 deletions

View File

@@ -258,9 +258,6 @@ class CosyVoice2Model(CosyVoiceModel):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.llm = llm
self.flow = flow
# NOTE default setting for jit/onnx export, you can set to False when using pytorch inference
self.flow.encoder.streaming = True
self.flow.decoder.estimator.streaming = True
self.hift = hift
self.fp16 = fp16
self.trt_concurrent = trt_concurrent
@@ -290,7 +287,7 @@ class CosyVoice2Model(CosyVoiceModel):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, finalize=False, speed=1.0):
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
tts_mel, _ = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
@@ -299,6 +296,7 @@ class CosyVoice2Model(CosyVoiceModel):
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
streaming=stream,
finalize=finalize)
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append hift cache
@@ -356,6 +354,7 @@ class CosyVoice2Model(CosyVoiceModel):
embedding=flow_embedding,
token_offset=token_offset,
uuid=this_uuid,
stream=stream,
finalize=False)
token_offset += this_token_hop_len
yield {'tts_speech': this_tts_speech.cpu()}