From 72b89a52fbf19d53fd86664a7e99d35769c278ae Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Thu, 26 Sep 2024 11:53:10 +0800
Subject: [PATCH] update vc/tts code

---
 cosyvoice/cli/model.py             | 77 +++++++++++++++++++++++++-----
 cosyvoice/flow/flow.py             |  2 +-
 cosyvoice/flow/length_regulator.py |  9 ++--
 cosyvoice/tokenizer/tokenizer.py   |  6 +--
 cosyvoice/utils/common.py          |  1 +
 5 files changed, 74 insertions(+), 21 deletions(-)

diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 51fa85b..1272b61 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -35,7 +35,7 @@ class CosyVoiceModel:
         self.token_max_hop_len = 200
         self.token_overlap_len = 20
         # mel fade in out
-        self.mel_overlap_len = 34
+        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
         self.mel_window = np.hamming(2 * self.mel_overlap_len)
         # hift cache
         self.mel_cache_len = 20
@@ -54,9 +54,10 @@ class CosyVoiceModel:
         self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
-        self.llm.to(self.device).eval()
-        self.llm.half()
+        if self.llm is not None:
+            self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
+            self.llm.to(self.device).eval()
+            self.llm.half()
         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
         self.flow.to(self.device).eval()
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
@@ -131,11 +132,11 @@ class CosyVoiceModel:
             tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
         return tts_speech
 
-    def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
-                  prompt_text=torch.zeros(1, 0, dtype=torch.int32),
-                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
-                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
-                  prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
+    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
         with self.lock:
@@ -148,7 +149,8 @@ class CosyVoiceModel:
             while True:
                 time.sleep(0.1)
                 if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
-                    this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
+                        .unsqueeze(dim=0)
                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                      prompt_token=flow_prompt_speech_token,
                                                      prompt_feat=prompt_speech_feat,
@@ -164,7 +166,7 @@ class CosyVoiceModel:
                     break
             p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
-            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
             this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
@@ -175,7 +177,58 @@ class CosyVoiceModel:
         else:
             # deal with all tokens
             p.join()
-            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True,
+                                             speed=speed)
+            yield {'tts_speech': this_tts_speech.cpu()}
+            with self.lock:
+                self.tts_speech_token_dict.pop(this_uuid)
+                self.llm_end_dict.pop(this_uuid)
+                self.mel_overlap_dict.pop(this_uuid)
+                self.hift_cache_dict.pop(this_uuid)
+
+    def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
+            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+        if stream is True:
+            token_hop_len = self.token_min_hop_len
+            while True:
+                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
+                        .unsqueeze(dim=0)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     finalize=False)
+                    yield {'tts_speech': this_tts_speech.cpu()}
+                    with self.lock:
+                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
+                    # increase token_hop_len for better speech quality
+                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
+                    break
+            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        else:
+            # deal with all tokens
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
             this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py
index c63b352..0fa6407 100644
--- a/cosyvoice/flow/flow.py
+++ b/cosyvoice/flow/flow.py
@@ -125,7 +125,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         h, h_lengths = self.encoder(token, token_len)
         h = self.encoder_proj(h)
         mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
-        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
+        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
 
         # get conditions
         conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
diff --git a/cosyvoice/flow/length_regulator.py b/cosyvoice/flow/length_regulator.py
index 26cb994..2cae42f 100644
--- a/cosyvoice/flow/length_regulator.py
+++ b/cosyvoice/flow/length_regulator.py
@@ -49,13 +49,14 @@ class InterpolateRegulator(nn.Module):
             olens = ylens
         return out * mask, olens
 
-    def inference(self, x1, x2, mel_len1, mel_len2):
+    def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
         # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
         # x in (B, T, D)
         if x2.shape[1] > 40:
-            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
-            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
-            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
+            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
+            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
+                                   mode='linear')
+            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
             x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
         else:
             x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
diff --git a/cosyvoice/tokenizer/tokenizer.py b/cosyvoice/tokenizer/tokenizer.py
index 131fc4e..caecf26 100644
--- a/cosyvoice/tokenizer/tokenizer.py
+++ b/cosyvoice/tokenizer/tokenizer.py
@@ -1,9 +1,7 @@
 import base64
 import os
-import string
-from dataclasses import dataclass, field
-from functools import cached_property, lru_cache
-from typing import Dict, List, Optional, Tuple
+from functools import lru_cache
+from typing import Optional
 
 from whisper.tokenizer import Tokenizer
 import tiktoken
diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py
index 832cbc7..9b91ab5 100644
--- a/cosyvoice/utils/common.py
+++ b/cosyvoice/utils/common.py
@@ -145,6 +145,7 @@ def fade_in_out(fade_in_mel, fade_out_mel, window):
         fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
     return fade_in_mel.to(device)
 
+
 def set_all_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
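
Note on the overlap arithmetic: the patch derives mel-frame counts from speech-token counts instead of hard-coding 34, so the fade window and the head/mid/tail interpolation sizes stay consistent for flow models with a different input_frame_rate. A quick sanity check of the conversion, assuming the defaults implied by the old code (20-token overlap, 50 Hz token rate, 22.05 kHz output, 256-sample hop):

```python
# Sketch only: reproduces the conversion used in the patch with assumed defaults.
token_overlap_len = 20   # tokens of overlap between streaming chunks
input_frame_rate = 50    # speech-token rate in Hz (assumed default)
sample_rate = 22050      # output waveform sample rate
hop_size = 256           # mel hop size in samples

mel_overlap_len = int(token_overlap_len / input_frame_rate * sample_rate / hop_size)
print(mel_overlap_len)   # 34 -> matches the value that was previously hard-coded
```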
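
For context, the renamed `tts` entry point and the new `vc` entry point are both generators that yield `{'tts_speech': ...}` chunks. Below is a minimal, hypothetical usage sketch; the `model` instance, the placeholder tensors, and the token-vocabulary size are assumptions for illustration, not part of the patch:

```python
import torch

# `model` is assumed to be an already-loaded CosyVoiceModel; the placeholder
# tensors below stand in for frontend-extracted tokens and speaker embeddings.
spk_embedding = torch.zeros(1, 192)                                  # speaker x-vector (placeholder)
text_tokens = torch.randint(0, 4096, (1, 50), dtype=torch.int32)    # text tokens (placeholder)

# Streaming TTS: chunks are yielded as soon as enough speech tokens are ready.
for chunk in model.tts(text=text_tokens, flow_embedding=spk_embedding,
                       llm_embedding=spk_embedding, stream=True):
    wav_chunk = chunk['tts_speech']   # (1, n_samples) waveform chunk on CPU

# Voice conversion: source speech tokens go straight to the flow + HiFT stack,
# bypassing the LLM (which is why `load` now tolerates self.llm being None).
source_tokens = torch.randint(0, 4096, (1, 200), dtype=torch.int32)  # source speech tokens (placeholder)
for chunk in model.vc(source_speech_token=source_tokens,
                      flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
                      prompt_speech_feat=torch.zeros(1, 0, 80),
                      flow_embedding=spk_embedding,
                      stream=False):
    wav_chunk = chunk['tts_speech']
```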