From c9acce1482fd20fa57f6dc242476fbaeccafcc50 Mon Sep 17 00:00:00 2001
From: boji123
Date: Fri, 20 Sep 2024 12:35:44 +0800
Subject: [PATCH 1/3] [debug] support flow cache, for sharper tts_mel output

---
 cosyvoice/cli/model.py          | 10 ++++++++--
 cosyvoice/flow/flow.py          | 12 ++++++++----
 cosyvoice/flow/flow_matching.py | 21 ++++++++++++++++++---
 3 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 489978d..542ea77 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -50,6 +50,7 @@ class CosyVoiceModel:
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
+        self.flow_cache_dict = {}
         self.mel_overlap_dict = {}
         self.hift_cache_dict = {}
 
@@ -92,13 +93,17 @@ class CosyVoiceModel:
             self.llm_end_dict[uuid] = True
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
-        tts_mel = self.flow.inference(token=token.to(self.device),
+        tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
                                       token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                       prompt_token=prompt_token.to(self.device),
                                       prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                       prompt_feat=prompt_feat.to(self.device),
                                       prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                      embedding=embedding.to(self.device))
+                                      embedding=embedding.to(self.device),
+                                      required_cache_size=self.mel_overlap_len,
+                                      flow_cache=self.flow_cache_dict[uuid])
+        self.flow_cache_dict[uuid] = flow_cache
+
         # mel overlap fade in out
         if self.mel_overlap_dict[uuid] is not None:
             tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
@@ -137,6 +142,7 @@ class CosyVoiceModel:
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.flow_cache_dict[this_uuid] = None
             self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py
index 0fa6407..50d96f7 100644
--- a/cosyvoice/flow/flow.py
+++ b/cosyvoice/flow/flow.py
@@ -109,7 +109,9 @@ class MaskedDiffWithXvec(torch.nn.Module):
                   prompt_token_len,
                   prompt_feat,
                   prompt_feat_len,
-                  embedding):
+                  embedding,
+                  required_cache_size=0,
+                  flow_cache=None):
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -133,13 +135,15 @@ class MaskedDiffWithXvec(torch.nn.Module):
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
-        feat = self.decoder(
+        feat, flow_cache = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
             spks=embedding,
             cond=conds,
-            n_timesteps=10
+            n_timesteps=10,
+            required_cache_size=required_cache_size,
+            flow_cache=flow_cache
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat
+        return feat, flow_cache
diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py
index 92afee2..4b1503b 100644
--- a/cosyvoice/flow/flow_matching.py
+++ b/cosyvoice/flow/flow_matching.py
@@ -32,7 +32,7 @@ class ConditionalCFM(BASECFM):
         self.estimator = estimator
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, required_cache_size=0, flow_cache=None):
         """Forward diffusion
 
         Args:
@@ -50,11 +50,26 @@ class ConditionalCFM(BASECFM):
             sample: generated mel-spectrogram
                 shape: (batch_size, n_feats, mel_timesteps)
         """
-        z = torch.randn_like(mu) * temperature
+
+        if flow_cache is not None:
+            z_cache = flow_cache[0]
+            mu_cache = flow_cache[1]
+            z = torch.randn((mu.size(0), mu.size(1), mu.size(2) - z_cache.size(2)), dtype=mu.dtype, device=mu.device) * temperature
+            z = torch.cat((z_cache, z), dim=2)  # [B, 80, T]
+            mu = torch.cat((mu_cache, mu[..., mu_cache.size(2):]), dim=2)  # [B, 80, T]
+        else:
+            z = torch.randn_like(mu) * temperature
+
+        next_cache_start = max(z.size(2) - required_cache_size, 0)
+        flow_cache = [
+            z[..., next_cache_start:],
+            mu[..., next_cache_start:]
+        ]
+
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
 
     def solve_euler(self, x, t_span, mu, mask, spks, cond):
         """

From 8130abb5ea67c77c1da7be2c3ab0e0f7707a8c52 Mon Sep 17 00:00:00 2001
From: boji123
Date: Sun, 29 Sep 2024 19:12:30 +0800
Subject: [PATCH 2/3] [debug] handle cache with prompt

---
 cosyvoice/flow/flow.py          | 1 +
 cosyvoice/flow/flow_matching.py | 6 +++---
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py
index 50d96f7..e430b83 100644
--- a/cosyvoice/flow/flow.py
+++ b/cosyvoice/flow/flow.py
@@ -141,6 +141,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
             spks=embedding,
             cond=conds,
             n_timesteps=10,
+            prompt_len=mel_len1,
             required_cache_size=required_cache_size,
             flow_cache=flow_cache
         )
diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py
index 4b1503b..83dc971 100644
--- a/cosyvoice/flow/flow_matching.py
+++ b/cosyvoice/flow/flow_matching.py
@@ -32,7 +32,7 @@ class ConditionalCFM(BASECFM):
         self.estimator = estimator
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, required_cache_size=0, flow_cache=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, required_cache_size=0, flow_cache=None):
         """Forward diffusion
 
         Args:
@@ -62,8 +62,8 @@ class ConditionalCFM(BASECFM):
 
         next_cache_start = max(z.size(2) - required_cache_size, 0)
         flow_cache = [
-            z[..., next_cache_start:],
-            mu[..., next_cache_start:]
+            torch.cat((z[..., :prompt_len], z[..., next_cache_start:]), dim=2),
+            torch.cat((mu[..., :prompt_len], mu[..., next_cache_start:]), dim=2)
         ]
 
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)

From a4db3db8ed0186c78253e79ece55c1b6f727f500 Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Wed, 16 Oct 2024 15:24:47 +0800
Subject: [PATCH 3/3] update flow cache

---
 cosyvoice/cli/model.py          | 28 +++++++++++++++-------------
 cosyvoice/flow/flow.py          |  4 +---
 cosyvoice/flow/flow_matching.py | 25 ++++++++++---------------
 3 files changed, 26 insertions(+), 31 deletions(-)

diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index f6f4808..1fcc31f 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -52,8 +52,8 @@ class CosyVoiceModel:
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
-        self.flow_cache_dict = {}
         self.mel_overlap_dict = {}
+        self.flow_cache_dict = {}
         self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
@@ -102,18 +102,17 @@ class CosyVoiceModel:
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
         tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
-                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                      prompt_token=prompt_token.to(self.device),
-                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                      prompt_feat=prompt_feat.to(self.device),
-                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                      embedding=embedding.to(self.device),
-                                      required_cache_size=self.mel_overlap_len,
-                                      flow_cache=self.flow_cache_dict[uuid])
+                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_token=prompt_token.to(self.device),
+                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_feat=prompt_feat.to(self.device),
+                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                                  embedding=embedding.to(self.device),
+                                                  flow_cache=self.flow_cache_dict[uuid])
         self.flow_cache_dict[uuid] = flow_cache
 
         # mel overlap fade in out
-        if self.mel_overlap_dict[uuid] is not None:
+        if self.mel_overlap_dict[uuid].shape[2] != 0:
             tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
         # append hift cache
         if self.hift_cache_dict[uuid] is not None:
@@ -150,8 +149,9 @@ class CosyVoiceModel:
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
-            self.flow_cache_dict[this_uuid] = None
-            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
         if stream is True:
@@ -207,7 +207,9 @@ class CosyVoiceModel:
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
-            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
         if stream is True:
             token_hop_len = self.token_min_hop_len
             while True:
diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py
index e430b83..eea705b 100644
--- a/cosyvoice/flow/flow.py
+++ b/cosyvoice/flow/flow.py
@@ -110,8 +110,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
                   prompt_feat,
                   prompt_feat_len,
                   embedding,
-                  required_cache_size=0,
-                  flow_cache=None):
+                  flow_cache):
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -142,7 +141,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
             cond=conds,
             n_timesteps=10,
             prompt_len=mel_len1,
-            required_cache_size=required_cache_size,
             flow_cache=flow_cache
         )
         feat = feat[:, :, mel_len1:]
diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py
index 83dc971..d011304 100644
--- a/cosyvoice/flow/flow_matching.py
+++ b/cosyvoice/flow/flow_matching.py
@@ -32,7 +32,7 @@ class ConditionalCFM(BASECFM):
         self.estimator = estimator
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, required_cache_size=0, flow_cache=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
         """Forward diffusion
 
         Args:
@@ -51,20 +51,15 @@ class ConditionalCFM(BASECFM):
                 shape: (batch_size, n_feats, mel_timesteps)
         """
 
-        if flow_cache is not None:
-            z_cache = flow_cache[0]
-            mu_cache = flow_cache[1]
-            z = torch.randn((mu.size(0), mu.size(1), mu.size(2) - z_cache.size(2)), dtype=mu.dtype, device=mu.device) * temperature
-            z = torch.cat((z_cache, z), dim=2)  # [B, 80, T]
-            mu = torch.cat((mu_cache, mu[..., mu_cache.size(2):]), dim=2)  # [B, 80, T]
-        else:
-            z = torch.randn_like(mu) * temperature
-
-        next_cache_start = max(z.size(2) - required_cache_size, 0)
-        flow_cache = [
-            torch.cat((z[..., :prompt_len], z[..., next_cache_start:]), dim=2),
-            torch.cat((mu[..., :prompt_len], mu[..., next_cache_start:]), dim=2)
-        ]
+        z = torch.randn_like(mu) * temperature
+        cache_size = flow_cache.shape[2]
+        # fix prompt and overlap part mu and z
+        if cache_size != 0:
+            z[:, :, :cache_size] = flow_cache[:, :, :, 0]
+            mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
+        z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
+        mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
+        flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
 
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
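
The cache contract that PATCH 3/3 converges on can be read in isolation: the first call draws fresh noise for every frame; each later call overwrites the leading cache_size frames of the new noise and mu with the cached values, so the prompt and overlap region are re-synthesized from identical inputs; the returned cache stacks (z, mu) over the prompt plus the last 34 frames. The snippet below is a minimal standalone sketch of that bookkeeping, not code from the repository: the helper name, batch size 1, 80 mel bins, and the toy prompt/chunk lengths are assumptions, while the 34-frame slice mirrors the hard-coded overlap in flow_matching.py.

import torch

def make_noise_with_cache(mu, flow_cache, prompt_len, temperature=1.0):
    # fresh noise for the whole sequence (prompt + current chunk)
    z = torch.randn_like(mu) * temperature
    cache_size = flow_cache.shape[2]
    if cache_size != 0:
        # restore the cached noise / condition over the prompt and overlap frames
        z[:, :, :cache_size] = flow_cache[:, :, :, 0]
        mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
    # next cache: prompt frames plus the last 34 frames of this chunk
    z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
    mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
    return z, mu, torch.stack([z_cache, mu_cache], dim=-1)

# per-session cache starts empty, matching the new initialization in model.py
flow_cache = torch.zeros(1, 80, 0, 2)
prompt_len = 50                                    # hypothetical prompt mel length
for chunk_frames in (120, 150):                    # hypothetical streaming chunk sizes
    mu = torch.randn(1, 80, prompt_len + chunk_frames)    # stand-in for the projected encoder output
    z, mu, flow_cache = make_noise_with_cache(mu, flow_cache, prompt_len)
    print(tuple(z.shape), tuple(flow_cache.shape))        # cache stays (1, 80, prompt_len + 34, 2)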