From ffa28e3bbda47952e758481b154e37173f0bc47d Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Sun, 29 Sep 2024 10:35:10 +0800
Subject: [PATCH] update token args

---
 cosyvoice/cli/model.py | 9 +++------
 cosyvoice/llm/llm.py   | 2 +-
 2 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index ea0ec4a..489978d 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -31,8 +31,8 @@ class CosyVoiceModel:
         self.llm = llm
         self.flow = flow
         self.hift = hift
-        self.token_min_hop_len = 100
-        self.token_max_hop_len = 200
+        self.token_min_hop_len = 2 * self.flow.input_frame_rate
+        self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
         # mel fade in out
         self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
@@ -87,10 +87,7 @@ class CosyVoiceModel:
                                         prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                         prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                        embedding=llm_embedding.to(self.device).half(),
-                                        sampling=25,
-                                        max_token_text_ratio=30,
-                                        min_token_text_ratio=3):
+                                        embedding=llm_embedding.to(self.device).half()):
                 self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py
index eb377f1..00e4af0 100644
--- a/cosyvoice/llm/llm.py
+++ b/cosyvoice/llm/llm.py
@@ -197,7 +197,7 @@ class TransformerLM(torch.nn.Module):
         offset = 0
         att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
         for i in range(max_len):
-            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1,
+            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                    att_cache=att_cache, cnn_cache=cnn_cache,
                                                                    att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                   device=lm_input.device)).to(torch.bool))