update flow cache

lyuxiang.lx
2024-10-16 15:24:47 +08:00
parent ace734def8
commit a4db3db8ed
3 changed files with 26 additions and 31 deletions

View File

@@ -52,8 +52,8 @@ class CosyVoiceModel:
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
-        self.flow_cache_dict = {}
         self.mel_overlap_dict = {}
+        self.flow_cache_dict = {}
         self.hift_cache_dict = {}

     def load(self, llm_model, flow_model, hift_model):
@@ -102,18 +102,17 @@ class CosyVoiceModel:
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
         tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
-                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                                  prompt_token=prompt_token.to(self.device),
-                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                  prompt_feat=prompt_feat.to(self.device),
-                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                                  embedding=embedding.to(self.device),
-                                                  required_cache_size=self.mel_overlap_len,
-                                                  flow_cache=self.flow_cache_dict[uuid])
+                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_token=prompt_token.to(self.device),
+                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_feat=prompt_feat.to(self.device),
+                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                                  embedding=embedding.to(self.device),
+                                                  flow_cache=self.flow_cache_dict[uuid])
         self.flow_cache_dict[uuid] = flow_cache
         # mel overlap fade in out
-        if self.mel_overlap_dict[uuid] is not None:
+        if self.mel_overlap_dict[uuid].shape[2] != 0:
             tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
         # append hift cache
         if self.hift_cache_dict[uuid] is not None:
@@ -150,8 +149,9 @@ class CosyVoiceModel:
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
-            self.flow_cache_dict[this_uuid] = None
-            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
         if stream is True:
@@ -207,7 +207,9 @@ class CosyVoiceModel:
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
-            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
         if stream is True:
             token_hop_len = self.token_min_hop_len
             while True:
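
Note: the net effect in this file is that the per-session caches are initialized as empty tensors instead of None, and the flow cache becomes a single tensor of shape (1, 80, 0, 2) whose last dimension stacks the cached noise and conditioning. A minimal sketch of the new session-state convention (illustrative only; 80 is the mel-bin count used throughout the diff):

    import torch

    # Empty caches: the time dimension starts at 0 and grows after the first chunk.
    mel_overlap = torch.zeros(1, 80, 0)      # cached mel frames for fade-in/out
    flow_cache = torch.zeros(1, 80, 0, 2)    # [..., 0] holds cached z, [..., 1] holds cached mu

    # Downstream code now branches on shape rather than on None:
    if mel_overlap.shape[2] != 0:
        pass  # apply fade_in_out with the cached overlap frames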

View File

@@ -110,8 +110,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
                   prompt_feat,
                   prompt_feat_len,
                   embedding,
-                  required_cache_size=0,
-                  flow_cache=None):
+                  flow_cache):
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -142,7 +141,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
             cond=conds,
             n_timesteps=10,
             prompt_len=mel_len1,
-            required_cache_size=required_cache_size,
             flow_cache=flow_cache
         )
         feat = feat[:, :, mel_len1:]
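
Note: with required_cache_size removed, the cache size is no longer negotiated inside MaskedDiffWithXvec; the caller passes the previous flow_cache in and receives the updated one back together with the mel, as the model.py hunk above shows. A hedged usage sketch of that contract (argument values are placeholders, not from the diff):

    # First chunk: start from an empty cache; later chunks: feed back the returned cache.
    flow_cache = torch.zeros(1, 80, 0, 2)
    tts_mel, flow_cache = flow.inference(token=token, token_len=token_len,
                                         prompt_token=prompt_token, prompt_token_len=prompt_token_len,
                                         prompt_feat=prompt_feat, prompt_feat_len=prompt_feat_len,
                                         embedding=embedding,
                                         flow_cache=flow_cache)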

View File

@@ -32,7 +32,7 @@ class ConditionalCFM(BASECFM):
         self.estimator = estimator

     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, required_cache_size=0, flow_cache=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
         """Forward diffusion

         Args:
@@ -51,20 +51,15 @@ class ConditionalCFM(BASECFM):
                 shape: (batch_size, n_feats, mel_timesteps)
         """
-        if flow_cache is not None:
-            z_cache = flow_cache[0]
-            mu_cache = flow_cache[1]
-            z = torch.randn((mu.size(0), mu.size(1), mu.size(2) - z_cache.size(2)), dtype=mu.dtype, device=mu.device) * temperature
-            z = torch.cat((z_cache, z), dim=2)  # [B, 80, T]
-            mu = torch.cat((mu_cache, mu[..., mu_cache.size(2):]), dim=2)  # [B, 80, T]
-        else:
-            z = torch.randn_like(mu) * temperature
-        next_cache_start = max(z.size(2) - required_cache_size, 0)
-        flow_cache = [
-            torch.cat((z[..., :prompt_len], z[..., next_cache_start:]), dim=2),
-            torch.cat((mu[..., :prompt_len], mu[..., next_cache_start:]), dim=2)
-        ]
+        z = torch.randn_like(mu) * temperature
+        cache_size = flow_cache.shape[2]
+        # fix prompt and overlap part mu and z
+        if cache_size != 0:
+            z[:, :, :cache_size] = flow_cache[:, :, :, 0]
+            mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
+        z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
+        mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
+        flow_cache = torch.stack([z_cache, mu_cache], dim=-1)

         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
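
Note: the new scheme replaces the old list-of-two-tensors cache with one stacked tensor. On each streaming chunk, the first cache_size frames of the fresh noise z and of the conditioning mu are overwritten with the values cached from the previous chunk, so the prompt and the 34-frame overlap region are denoised from the same inputs as before, which keeps the overlapping mel frames consistent for the fade-in/out in model.py. A small standalone sketch of the mechanism (prompt_len and the 34-frame overlap come from the diff; the sizes and the run_chunk helper are illustrative):

    import torch

    def run_chunk(mu, prompt_len, flow_cache, overlap=34):
        # Fresh noise for the whole chunk, then overwrite the cached region so the
        # prompt and overlap frames reuse the previous chunk's z and mu exactly.
        z = torch.randn_like(mu)
        cache_size = flow_cache.shape[2]
        if cache_size != 0:
            z[:, :, :cache_size] = flow_cache[:, :, :, 0]
            mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
        # New cache: prompt region plus the last `overlap` frames, stacked along dim=-1.
        z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -overlap:]], dim=2)
        mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -overlap:]], dim=2)
        return torch.stack([z_cache, mu_cache], dim=-1)

    prompt_len = 50
    cache = torch.zeros(1, 80, 0, 2)                        # empty cache before the first chunk
    cache = run_chunk(torch.randn(1, 80, 120), prompt_len, cache)
    print(cache.shape)                                      # torch.Size([1, 80, 84, 2])
    cache = run_chunk(torch.randn(1, 80, 150), prompt_len, cache)
    print(cache.shape)                                      # torch.Size([1, 80, 84, 2])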