diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index a3710f4..d09403e 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -25,6 +25,7 @@ class CosyVoice:
 
     def __init__(self, model_dir, load_jit=True, load_onnx=False):
         instruct = True if '-Instruct' in model_dir else False
+        vc = True if '-VC' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
@@ -36,6 +37,7 @@ class CosyVoice:
                                           '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                           '{}/spk2info.pt'.format(model_dir),
                                           instruct,
+                                          vc,
                                           configs['allowed_special'])
         self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
         self.model.load('{}/llm.pt'.format(model_dir),
@@ -58,7 +60,7 @@ class CosyVoice:
             model_input = self.frontend.frontend_sft(i, spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
@@ -70,7 +72,7 @@ class CosyVoice:
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
@@ -83,7 +85,7 @@ class CosyVoice:
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
@@ -97,8 +99,17 @@ class CosyVoice:
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
             start_time = time.time()
+
+    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
+        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k)
+        start_time = time.time()
+        for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
+            speech_len = model_output['tts_speech'].shape[1] / 22050
+            logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+            yield model_output
+        start_time = time.time()
diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py
index 69557d8..d9d97ee 100644
--- a/cosyvoice/cli/frontend.py
+++ b/cosyvoice/cli/frontend.py
@@ -42,6 +42,7 @@ class CosyVoiceFrontEnd:
                  speech_tokenizer_model: str,
                  spk2info: str = '',
                  instruct: bool = False,
+                 vc: bool = False,
                  allowed_special: str = 'all'):
         self.tokenizer = get_tokenizer()
         self.feat_extractor = feat_extractor
@@ -55,7 +56,10 @@ class CosyVoiceFrontEnd:
                                                                                "CPUExecutionProvider"])
         if os.path.exists(spk2info):
             self.spk2info = torch.load(spk2info, map_location=self.device)
+        else:
+            self.spk2info = {}
         self.instruct = instruct
+        self.vc = vc
         self.allowed_special = allowed_special
         self.inflect_parser = inflect.engine()
         self.use_ttsfrd = use_ttsfrd
@@ -172,3 +176,15 @@ class CosyVoiceFrontEnd:
         model_input['prompt_text'] = instruct_text_token
         model_input['prompt_text_len'] = instruct_text_token_len
         return model_input
+
+    def frontend_vc(self, source_speech_16k, prompt_speech_16k):
+        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
+        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
+        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
+        embedding = self._extract_spk_embedding(prompt_speech_16k)
+        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
+        model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
+                       'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
+                       'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
+                       'flow_embedding': embedding}
+        return model_input
diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py
index 48835b8..c63b352 100644
--- a/cosyvoice/flow/flow.py
+++ b/cosyvoice/flow/flow.py
@@ -124,7 +124,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         # text encode
         h, h_lengths = self.encoder(token, token_len)
         h = self.encoder_proj(h)
-        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / 50 * 22050 / 256)
+        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
         h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
 
         # get conditions
@@ -132,7 +132,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
         conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)
 
-        # mask = (~make_pad_mask(feat_len)).to(h)
         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
         feat = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py
index dcd087a..eb377f1 100644
--- a/cosyvoice/llm/llm.py
+++ b/cosyvoice/llm/llm.py
@@ -206,7 +206,7 @@ class TransformerLM(torch.nn.Module):
             if top_ids == self.speech_token_size:
                 break
             # in stream mode, yield token one by one
-            yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
+            yield top_ids
             out_tokens.append(top_ids)
             offset += lm_input.size(1)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
diff --git a/cosyvoice/tokenizer/tokenizer.py b/cosyvoice/tokenizer/tokenizer.py
index ff5adbf..131fc4e 100644
--- a/cosyvoice/tokenizer/tokenizer.py
+++ b/cosyvoice/tokenizer/tokenizer.py
@@ -4,6 +4,7 @@ import string
 from dataclasses import dataclass, field
 from functools import cached_property, lru_cache
 from typing import Dict, List, Optional, Tuple
+from whisper.tokenizer import Tokenizer
 
 import tiktoken
 
@@ -165,208 +166,6 @@ TTS_Vocal_Token = {
 }
 
 
-@dataclass
-class Tokenizer:
-    """A thin wrapper around `tiktoken` providing quick access to special tokens"""
-
-    encoding: tiktoken.Encoding
-    num_languages: int
-    language: Optional[str] = None
-    task: Optional[str] = None
-    sot_sequence: Tuple[int] = ()
-    special_tokens: Dict[str, int] = field(default_factory=dict)
-
-    def __post_init__(self):
-        for special in self.encoding.special_tokens_set:
-            special_token = self.encoding.encode_single_token(special)
-            self.special_tokens[special] = special_token
-
-        sot: int = self.special_tokens["<|startoftranscript|>"]
-        translate: int = self.special_tokens["<|translate|>"]
-        transcribe: int = self.special_tokens["<|transcribe|>"]
-
-        langs = tuple(LANGUAGES.keys())[: self.num_languages]
-        sot_sequence = [sot]
-        if self.language is not None:
-            sot_sequence.append(sot + 1 + langs.index(self.language))
-        if self.task is not None:
-            task_token: int = transcribe if self.task == "transcribe" else translate
-            sot_sequence.append(task_token)
-
-        self.sot_sequence = tuple(sot_sequence)
-
-    def encode(self, text, **kwargs):
-        return self.encoding.encode(text, **kwargs)
-
-    def decode(self, token_ids: List[int], **kwargs) -> str:
-        token_ids = [t for t in token_ids if t < self.timestamp_begin]
-        return self.encoding.decode(token_ids, **kwargs)
-
-    def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
-        """
-        Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
-        This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
-        """
-        return self.encoding.decode(token_ids, **kwargs)
-
-    def get_vocab_size(self) -> int:
-        return self.encoding.n_vocab
-
-    @cached_property
-    def eot(self) -> int:
-        return self.encoding.eot_token
-
-    @cached_property
-    def transcribe(self) -> int:
-        return self.special_tokens["<|transcribe|>"]
-
-    @cached_property
-    def translate(self) -> int:
-        return self.special_tokens["<|translate|>"]
-
-    @cached_property
-    def sot(self) -> int:
-        return self.special_tokens["<|startoftranscript|>"]
-
-    @cached_property
-    def sot_lm(self) -> int:
-        return self.special_tokens["<|startoflm|>"]
-
-    @cached_property
-    def sot_prev(self) -> int:
-        return self.special_tokens["<|startofprev|>"]
-
-    @cached_property
-    def no_speech(self) -> int:
-        return self.special_tokens["<|nospeech|>"]
-
-    @cached_property
-    def no_timestamps(self) -> int:
-        return self.special_tokens["<|notimestamps|>"]
-
-    @cached_property
-    def timestamp_begin(self) -> int:
-        return self.special_tokens["<|0.00|>"]
-
-    @cached_property
-    def language_token(self) -> int:
-        """Returns the token id corresponding to the value of the `language` field"""
-        if self.language is None:
-            raise ValueError("This tokenizer does not have language token configured")
-
-        return self.to_language_token(self.language)
-
-    def to_language_token(self, language):
-        if token := self.special_tokens.get(f"<|{language}|>", None):
-            return token
-
-        raise KeyError(f"Language {language} not found in tokenizer.")
-
-    @cached_property
-    def all_language_tokens(self) -> Tuple[int]:
-        result = []
-        for token, token_id in self.special_tokens.items():
-            if token.strip("<|>") in LANGUAGES:
-                result.append(token_id)
-        return tuple(result)[: self.num_languages]
-
-    @cached_property
-    def all_language_codes(self) -> Tuple[str]:
-        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
-
-    @cached_property
-    def sot_sequence_including_notimestamps(self) -> Tuple[int]:
-        return tuple(list(self.sot_sequence) + [self.no_timestamps])
-
-    @cached_property
-    def non_speech_tokens(self) -> Tuple[int]:
-        """
-        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
-        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
-
-        - ♪♪♪
-        - ( SPEAKING FOREIGN LANGUAGE )
-        - [DAVID] Hey there,
-
-        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
-        """
-        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
-        symbols += (
-            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
-        )
-
-        # symbols that may be a single token or multiple tokens depending on the tokenizer.
-        # In case they're multiple tokens, suppress the first token, which is safe because:
-        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
-        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
-        miscellaneous = set("♩♪♫♬♭♮♯")
-        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
-
-        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
-        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
-        for symbol in symbols + list(miscellaneous):
-            for tokens in [
-                self.encoding.encode(symbol),
-                self.encoding.encode(" " + symbol),
-            ]:
-                if len(tokens) == 1 or symbol in miscellaneous:
-                    result.add(tokens[0])
-
-        return tuple(sorted(result))
-
-    def split_to_word_tokens(self, tokens: List[int]):
-        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
-            # These languages don't typically use spaces, so it is difficult to split words
-            # without morpheme analysis. Here, we instead split words at any
-            # position where the tokens are decoded as valid unicode points
-            return self.split_tokens_on_unicode(tokens)
-
-        return self.split_tokens_on_spaces(tokens)
-
-    def split_tokens_on_unicode(self, tokens: List[int]):
-        decoded_full = self.decode_with_timestamps(tokens)
-        replacement_char = "\ufffd"
-
-        words = []
-        word_tokens = []
-        current_tokens = []
-        unicode_offset = 0
-
-        for token in tokens:
-            current_tokens.append(token)
-            decoded = self.decode_with_timestamps(current_tokens)
-
-            if (
-                replacement_char not in decoded
-                or decoded_full[unicode_offset + decoded.index(replacement_char)]
-                == replacement_char
-            ):
-                words.append(decoded)
-                word_tokens.append(current_tokens)
-                current_tokens = []
-                unicode_offset += len(decoded)
-
-        return words, word_tokens
-
-    def split_tokens_on_spaces(self, tokens: List[int]):
-        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
-        words = []
-        word_tokens = []
-
-        for subword, subword_tokens in zip(subwords, subword_tokens_list):
-            special = subword_tokens[0] >= self.eot
-            with_space = subword.startswith(" ")
-            punctuation = subword.strip() in string.punctuation
-            if special or with_space or punctuation or len(words) == 0:
-                words.append(subword)
-                word_tokens.append(subword_tokens)
-            else:
-                words[-1] = words[-1] + subword
-                word_tokens[-1].extend(subword_tokens)
-
-        return words, word_tokens
-
-
 @lru_cache(maxsize=None)
 def get_encoding(name: str = "gpt2", num_languages: int = 99):
     vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py
index 2a87657..832cbc7 100644
--- a/cosyvoice/utils/common.py
+++ b/cosyvoice/utils/common.py
@@ -15,8 +15,10 @@
 # Modified from ESPnet(https://github.com/espnet/espnet)
 """Unility functions for Transformer."""
 
+import random
 from typing import List
 
+import numpy as np
 import torch
 
 IGNORE_ID = -1
@@ -142,3 +144,9 @@ def fade_in_out(fade_in_mel, fade_out_mel, window):
     fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
         fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
     return fade_in_mel.to(device)
+
+def set_all_random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
diff --git a/webui.py b/webui.py
index 2f41196..0233a94 100644
--- a/webui.py
+++ b/webui.py
@@ -24,6 +24,7 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
 from cosyvoice.cli.cosyvoice import CosyVoice
 from cosyvoice.utils.file_utils import load_wav, logging
+from cosyvoice.utils.common import set_all_random_seed
 
 inference_mode_list = ['预训练音色', '3s极速复刻', '跨语种复刻', '自然语言控制']
 instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成音频按钮',
@@ -42,13 +43,6 @@ def generate_seed():
     }
 
 
-def set_all_random_seed(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-
-
 def postprocess(speech, top_db=60, hop_length=220, win_length=440):
     speech, _ = librosa.effects.trim(
         speech, top_db=top_db,
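For reference, a minimal usage sketch of the voice-conversion entry point added by this patch. The model directory and wav paths below are illustrative assumptions, not part of the change; only CosyVoice.inference_vc and its yielded 'tts_speech' tensors at 22050 Hz come from the diff itself.

# Usage sketch (assumed paths and model id), not part of the patch.
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-VC')  # assumed model dir; a '-VC' suffix enables the vc flag
source_speech_16k = load_wav('source.wav', 16000)             # speech whose content is preserved (assumed path)
prompt_speech_16k = load_wav('prompt.wav', 16000)             # speech providing the target timbre (assumed path)
for i, out in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
    torchaudio.save('vc_output_{}.wav'.format(i), out['tts_speech'], 22050)  # outputs are yielded at 22050 Hz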