Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-04 17:39:25 +08:00)
Add speech fade-in/fade-out
This commit is contained in:
@@ -40,6 +40,8 @@ class CosyVoiceModel:
|
|||||||
# hift cache
|
# hift cache
|
||||||
self.mel_cache_len = 20
|
self.mel_cache_len = 20
|
||||||
self.source_cache_len = int(self.mel_cache_len * 256)
|
self.source_cache_len = int(self.mel_cache_len * 256)
|
||||||
|
# speech fade in out
|
||||||
|
self.speech_window = np.hamming(2 * self.source_cache_len)
|
||||||
# rtf and decoding related
|
# rtf and decoding related
|
||||||
self.stream_scale_factor = 1
|
self.stream_scale_factor = 1
|
||||||
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
||||||
@@ -50,7 +52,6 @@ class CosyVoiceModel:
|
|||||||
self.llm_end_dict = {}
|
self.llm_end_dict = {}
|
||||||
self.mel_overlap_dict = {}
|
self.mel_overlap_dict = {}
|
||||||
self.hift_cache_dict = {}
|
self.hift_cache_dict = {}
|
||||||
self.speech_window = np.hamming(2 * self.source_cache_len)
|
|
||||||
|
|
||||||
def load(self, llm_model, flow_model, hift_model):
|
def load(self, llm_model, flow_model, hift_model):
|
||||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
|
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
|
||||||
@@ -117,10 +118,9 @@ class CosyVoiceModel:
|
|||||||
tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
|
tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
|
||||||
if self.hift_cache_dict[uuid] is not None:
|
if self.hift_cache_dict[uuid] is not None:
|
||||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||||
self.hift_cache_dict[uuid] = {
|
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
||||||
'mel': tts_mel[:, :, -self.mel_cache_len:],
|
'source': tts_source[:, :, -self.source_cache_len:],
|
||||||
'source': tts_source[:, :, -self.source_cache_len:],
|
'speech': tts_speech[:, -self.source_cache_len:]}
|
||||||
'speech': tts_speech[:, -self.source_cache_len:]}
|
|
||||||
tts_speech = tts_speech[:, :-self.source_cache_len]
|
tts_speech = tts_speech[:, :-self.source_cache_len]
|
||||||
else:
|
else:
|
||||||
if speed != 1.0:
|
if speed != 1.0:
|
||||||
|
|||||||
@@ -139,7 +139,6 @@ def fade_in_out(fade_in_mel, fade_out_mel, window):
|
|||||||
device = fade_in_mel.device
|
device = fade_in_mel.device
|
||||||
fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
|
fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
|
||||||
mel_overlap_len = int(window.shape[0] / 2)
|
mel_overlap_len = int(window.shape[0] / 2)
|
||||||
|
|
||||||
fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
|
fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
|
||||||
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
|
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
|
||||||
return fade_in_mel.to(device)
|
return fade_in_mel.to(device)
|
||||||
|
|||||||
Reference in New Issue
Block a user