mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 18:09:24 +08:00)

update token args
@@ -31,8 +31,8 @@ class CosyVoiceModel:
         self.llm = llm
         self.flow = flow
         self.hift = hift
-        self.token_min_hop_len = 100
-        self.token_max_hop_len = 200
+        self.token_min_hop_len = 2 * self.flow.input_frame_rate
+        self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
         # mel fade in out
         self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
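The hunk above replaces hard-coded hop lengths with values derived from the flow model's input frame rate, so the min/max hop windows stay at 2 s and 4 s of speech tokens regardless of token rate; the old constants 100 and 200 only match a 50 token/s flow model. A minimal sketch of the arithmetic, where input_frame_rate = 50 is an assumption chosen to reproduce the old constants:

# Sketch of the new hop-length arithmetic; input_frame_rate = 50 is an
# assumption that reproduces the old hard-coded values.
input_frame_rate = 50

token_min_hop_len = 2 * input_frame_rate   # 2 s of tokens -> 100 at 50 token/s
token_max_hop_len = 4 * input_frame_rate   # 4 s of tokens -> 200 at 50 token/s

# The mel overlap line in the same block converts the token overlap to mel
# frames at a 22050 Hz sample rate with a 256-sample hop:
token_overlap_len = 20
mel_overlap_len = int(token_overlap_len / input_frame_rate * 22050 / 256)

print(token_min_hop_len, token_max_hop_len, mel_overlap_len)  # 100 200 34

At a 25 token/s flow model the hops would become 50 and 100 tokens, preserving the same 2 s / 4 s streaming windows instead of the fixed 100 / 200.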
@@ -87,10 +87,7 @@ class CosyVoiceModel:
                                         prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                         prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                        embedding=llm_embedding.to(self.device).half(),
-                                        sampling=25,
-                                        max_token_text_ratio=30,
-                                        min_token_text_ratio=3):
+                                        embedding=llm_embedding.to(self.device).half()):
                 self.tts_speech_token_dict[uuid].append(i)
             self.llm_end_dict[uuid] = True
 
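This hunk drops sampling=25, max_token_text_ratio=30, and min_token_text_ratio=3 from the llm.inference(...) call and closes the call after the embedding argument. Consistent with the commit title, these presumably move to keyword defaults on the inference method itself; the sketch below shows that assumed signature only, with defaults mirroring the deleted call-site values, none of which is confirmed by this diff:

import torch

class TransformerLM(torch.nn.Module):
    # Sketch only: an assumed post-change signature in which the sampling
    # arguments deleted from the call site become keyword defaults.
    # The real signature in llm.py is not shown in this diff.
    def inference(self, text, text_len, prompt_text, prompt_text_len,
                  prompt_speech_token, prompt_speech_token_len, embedding,
                  sampling=25, max_token_text_ratio=30, min_token_text_ratio=3):
        ...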
@@ -197,7 +197,7 @@ class TransformerLM(torch.nn.Module):
         offset = 0
         att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
         for i in range(max_len):
-            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1,
+            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                   att_cache=att_cache, cnn_cache=cnn_cache,
                                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                  device=lm_input.device)).to(torch.bool))
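The last hunk is a bug fix rather than an argument cleanup: forward_chunk was called with offset=0 on every decoding step, so each incremental chunk was positioned at the start of the sequence even though att_cache already held the previously decoded tokens. Passing the running offset keeps absolute positions advancing across steps. A minimal, self-contained sketch of the bookkeeping, using a stand-in forward_chunk that merely reports the positions a chunk would be encoded at:

import torch

def forward_chunk(chunk, offset, att_cache):
    # Stand-in for the real encoder call: report the absolute positions
    # this chunk is encoded at, and grow the attention cache.
    positions = torch.arange(offset, offset + chunk.shape[1])
    new_cache = torch.cat([att_cache, chunk], dim=1)
    return positions, new_cache

lm_input = torch.zeros(1, 1, 8)      # one token embedding per step
att_cache = torch.zeros(1, 0, 8)
offset = 0
for step in range(3):
    positions, att_cache = forward_chunk(lm_input, offset, att_cache)
    print(step, positions.tolist())  # prints [0], then [1], then [2]
    offset += lm_input.shape[1]      # the fix: advance instead of resetting to 0

With the old offset=0, every step would print [0] and the cached-attention positions would collapse onto the sequence start.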