This commit is contained in:
lyuxiang.lx
2025-08-22 14:24:27 +08:00
parent dc96e4c984
commit 6b5eef62cc

View File

@@ -403,11 +403,6 @@ class CosyVoice3Model(CosyVoice2Model):
self.flow.half()
# NOTE must matching training static_chunk_size
self.token_hop_len = 25
# hift cache
self.mel_cache_len = 8
self.source_cache_len = int(self.mel_cache_len * 480)
# speech fade in out
self.speech_window = np.hamming(2 * self.source_cache_len)
# rtf and decoding related
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.lock = threading.Lock()
@@ -428,26 +423,17 @@ class CosyVoice3Model(CosyVoice2Model):
streaming=stream,
finalize=finalize)
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append hift cache
# append mel cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
hift_cache_mel = self.hift_cache_dict[uuid]['mel']
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
self.hift_cache_dict[uuid]['mel'] = tts_mel
else:
hift_cache_source = torch.zeros(1, 1, 0)
# keep overlap mel and hift cache
if finalize is False:
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
'source': tts_source[:, :, -self.source_cache_len:],
'speech': tts_speech[:, -self.source_cache_len:]}
tts_speech = tts_speech[:, :-self.source_cache_len]
else:
if speed != 1.0:
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
self.hift_cache_dict[uuid] = {'mel': tts_mel, 'speech_offset': 0}
if speed != 1.0:
assert token_offset == 0 and finalize is True, 'speed change only support non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize)
tts_speech = tts_speech[:, self.hift_cache_dict[uuid]['speech_offset']:]
self.hift_cache_dict[uuid]['speech_offset'] += tts_speech.shape[1]
return tts_speech