Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 01:49:25 +08:00)
[debug] support flow cache, for sharper tts_mel output
@@ -50,6 +50,7 @@ class CosyVoiceModel:
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
+        self.flow_cache_dict = {}
         self.mel_overlap_dict = {}
         self.hift_cache_dict = {}
 
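CosyVoiceModel keeps all streaming state in plain dicts keyed by a per-request uuid, and the new flow_cache_dict follows the same convention as the existing mel_overlap_dict and hift_cache_dict. A minimal sketch of that lifecycle (open_session/close_session are hypothetical helper names, not the repo's code):

# Per-session state, keyed by a uuid string (mirrors self.flow_cache_dict).
flow_cache_dict = {}

def open_session(session_id):
    flow_cache_dict[session_id] = None       # None = cold start for the first chunk

def close_session(session_id):
    flow_cache_dict.pop(session_id, None)    # drop cached tensors when a request ends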
@@ -92,13 +93,17 @@ class CosyVoiceModel:
         self.llm_end_dict[uuid] = True
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
-        tts_mel = self.flow.inference(token=token.to(self.device),
+        tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
                                       token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                       prompt_token=prompt_token.to(self.device),
                                       prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                       prompt_feat=prompt_feat.to(self.device),
                                       prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                      embedding=embedding.to(self.device))
+                                      embedding=embedding.to(self.device),
+                                      required_cache_size=self.mel_overlap_len,
+                                      flow_cache=self.flow_cache_dict[uuid])
+        self.flow_cache_dict[uuid] = flow_cache
+
         # mel overlap fade in out
         if self.mel_overlap_dict[uuid] is not None:
             tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
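With this change, self.flow.inference returns a (tts_mel, flow_cache) pair, and token2wav threads the per-session cache through every call: it passes in the cache left by the previous chunk and stores the returned one under the same uuid, with required_cache_size tied to self.mel_overlap_len so the cached frames cover exactly the mel overlap window. A hedged sketch of the resulting streaming call pattern (model, token_chunks, and the prompt tensors are placeholders, not repo code):

import uuid as uuid_lib

this_uuid = str(uuid_lib.uuid1())
model.flow_cache_dict[this_uuid] = None      # seeded before the first chunk (see tts() below)
for i, chunk in enumerate(token_chunks):
    # each call consumes the previous chunk's cache and stores the new one under this_uuid
    speech = model.token2wav(chunk, prompt_token, prompt_feat, embedding,
                             this_uuid, finalize=(i == len(token_chunks) - 1))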
@@ -137,6 +142,7 @@ class CosyVoiceModel:
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.flow_cache_dict[this_uuid] = None
             self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
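Seeding flow_cache_dict[this_uuid] with None is what makes the first chunk of a session take the cold-start path: ConditionalCFM.forward (last hunk below) only splices cached tensors in when the cache is not None. Roughly (illustrative fragment, not repo code):

import torch

mu = torch.randn(1, 80, 30)     # illustrative conditioning for a first chunk
flow_cache = None               # fresh session, as seeded above
if flow_cache is None:
    z = torch.randn_like(mu)    # full-length noise: exactly the pre-commit behavior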
@@ -109,7 +109,9 @@ class MaskedDiffWithXvec(torch.nn.Module):
                   prompt_token_len,
                   prompt_feat,
                   prompt_feat_len,
-                  embedding):
+                  embedding,
+                  required_cache_size=0,
+                  flow_cache=None):
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -133,13 +135,15 @@ class MaskedDiffWithXvec(torch.nn.Module):
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
-        feat = self.decoder(
+        feat, flow_cache = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
             spks=embedding,
             cond=conds,
-            n_timesteps=10
+            n_timesteps=10,
+            required_cache_size=required_cache_size,
+            flow_cache=flow_cache
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat
+        return feat, flow_cache
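MaskedDiffWithXvec.inference now returns a (feat, flow_cache) tuple instead of a bare tensor, so every call site has to unpack two values, as token2wav above already does; the cache itself is opaque to this class and is simply threaded into and out of self.decoder. A sketch of the new call contract (variable names such as mel_overlap_len are illustrative):

feat, flow_cache = flow.inference(token=token, token_len=token_len,
                                  prompt_token=prompt_token, prompt_token_len=prompt_token_len,
                                  prompt_feat=prompt_feat, prompt_feat_len=prompt_feat_len,
                                  embedding=embedding,
                                  required_cache_size=mel_overlap_len,  # trailing frames to keep
                                  flow_cache=None)                      # None on the first chunk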
@@ -32,7 +32,7 @@ class ConditionalCFM(BASECFM):
         self.estimator = estimator
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, required_cache_size=0, flow_cache=None):
         """Forward diffusion
 
         Args:
@@ -50,11 +50,26 @@ class ConditionalCFM(BASECFM):
             sample: generated mel-spectrogram
                 shape: (batch_size, n_feats, mel_timesteps)
         """
-        z = torch.randn_like(mu) * temperature
+        if flow_cache is not None:
+            z_cache = flow_cache[0]
+            mu_cache = flow_cache[1]
+            z = torch.randn((mu.size(0), mu.size(1), mu.size(2) - z_cache.size(2)), dtype=mu.dtype, device=mu.device) * temperature
+            z = torch.cat((z_cache, z), dim=2)  # [B, 80, T]
+            mu = torch.cat((mu_cache, mu[..., mu_cache.size(2):]), dim=2)  # [B, 80, T]
+        else:
+            z = torch.randn_like(mu) * temperature
+
+        next_cache_start = max(z.size(2) - required_cache_size, 0)
+        flow_cache = [
+            z[..., next_cache_start:],
+            mu[..., next_cache_start:]
+        ]
+
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
 
     def solve_euler(self, x, t_span, mu, mask, spks, cond):
         """
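This hunk is the heart of the commit. The CFM solver is deterministic given the initial noise z and the conditioning mu, so re-sampling z for frames already generated in the previous chunk would produce a different mel over the overlap region, which the fade_in_out in token2wav then has to blur together. Caching the trailing required_cache_size frames of both z and mu and splicing them back in makes consecutive chunks agree exactly on the overlap, hence the "sharper tts_mel output" in the commit title. A self-contained sketch of just the caching arithmetic (illustrative, not repo code; a real call would run the ODE solve on the returned z and mu):

import torch

def forward_with_cache(mu, required_cache_size, flow_cache=None, temperature=1.0):
    # Splice cached noise/conditioning over the overlap, sample fresh noise for the rest.
    if flow_cache is not None:
        z_cache, mu_cache = flow_cache
        new_len = mu.size(2) - z_cache.size(2)
        z = torch.randn(mu.size(0), mu.size(1), new_len, dtype=mu.dtype) * temperature
        z = torch.cat((z_cache, z), dim=2)                              # re-use cached noise
        mu = torch.cat((mu_cache, mu[..., mu_cache.size(2):]), dim=2)   # overwrite overlap with cached mu
    else:
        z = torch.randn_like(mu) * temperature                          # cold start
    next_cache_start = max(z.size(2) - required_cache_size, 0)
    flow_cache = [z[..., next_cache_start:], mu[..., next_cache_start:]]
    return z, mu, flow_cache

# chunk 1: 30 mel frames, keep the last 10 as cache
mu1 = torch.randn(1, 80, 30)
z1, _, cache = forward_with_cache(mu1, required_cache_size=10)

# chunk 2 starts with chunk 1's last 10 frames (the overlap), then 25 new frames
mu2 = torch.cat((mu1[..., -10:], torch.randn(1, 80, 25)), dim=2)
z2, _, cache = forward_with_cache(mu2, required_cache_size=10, flow_cache=cache)

# identical noise and conditioning over the overlap -> identical mel after the ODE solve
assert torch.equal(z1[..., -10:], z2[..., :10])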