From b9ddcba5fd59d13436033b0ff2425f8ff335d657 Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Mon, 30 Dec 2024 16:41:57 +0800
Subject: [PATCH] add some instruction and assert

---
 README.md                      |  9 +++----
 cosyvoice/cli/cosyvoice.py     | 42 ++++++++++++++++---------------
 cosyvoice/cli/frontend.py      | 46 ++++++++--------------------
 cosyvoice/cli/model.py         |  2 ++
 cosyvoice/utils/class_utils.py | 12 +++++++++
 webui.py                       |  4 +--
 6 files changed, 52 insertions(+), 63 deletions(-)

diff --git a/README.md b/README.md
index 4a23e44..077cc12 100644
--- a/README.md
+++ b/README.md
@@ -121,13 +121,10 @@ We strongly recommend using `CosyVoice2-0.5B` for better performance.
 For zero_shot/cross_lingual inference, please use `CosyVoice-300M` model.
 For sft inference, please use `CosyVoice-300M-SFT` model.
 For instruct inference, please use `CosyVoice-300M-Instruct` model.
-First, add `third_party/Matcha-TTS` to your `PYTHONPATH`.
-
-``` sh
-export PYTHONPATH=third_party/Matcha-TTS
-```

 ``` python
+import sys
+sys.path.append('third_party/Matcha-TTS')
 from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 from cosyvoice.utils.file_utils import load_wav
 import torchaudio
@@ -161,7 +158,7 @@ print(cosyvoice.list_available_spks())
 for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
     torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)

-cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-25Hz') # or change to pretrained_models/CosyVoice-300M for 50Hz inference
+cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M') # or change to pretrained_models/CosyVoice-300M-25Hz for 25Hz inference
 # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
 prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
 for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index 3f848a8..420d761 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -20,23 +20,24 @@ import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.class_utils import get_model_type


 class CosyVoice:

     def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
-        instruct = True if '-Instruct' in model_dir else False
+        self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
             configs = load_hyperpyyaml(f)
+        assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
                                           '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                           '{}/spk2info.pt'.format(model_dir),
-                                          instruct,
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
         if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
@@ -85,8 +86,6 @@ class CosyVoice:
             start_time = time.time()

     def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
-        if self.frontend.instruct is True and isinstance(self.model, CosyVoiceModel):
-            raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
@@ -98,8 +97,8 @@ class CosyVoice:
             start_time = time.time()

     def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoiceModel)
-        if self.frontend.instruct is False:
+        assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
+        if self.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
         instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
@@ -112,18 +111,6 @@ class CosyVoice:
             yield model_output
             start_time = time.time()

-    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoice2Model)
-        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
-            start_time = time.time()
-            logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
-                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
-                yield model_output
-                start_time = time.time()
-
     def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
         model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
         start_time = time.time()
@@ -137,18 +124,18 @@ class CosyVoice:
 class CosyVoice2(CosyVoice):

     def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
-        instruct = True if '-Instruct' in model_dir else False
+        self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
                                           '{}/speech_tokenizer_v2.onnx'.format(model_dir),
                                           '{}/spk2info.pt'.format(model_dir),
-                                          instruct,
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
         if torch.cuda.is_available() is False and load_jit is True:
@@ -168,3 +155,18 @@ class CosyVoice2(CosyVoice):
         if load_trt:
             self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
         del configs
+
+    def inference_instruct(self, *args, **kwargs):
+        raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
+
+    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
+        assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py
index 14cb0cb..ab59a93 100644
--- a/cosyvoice/cli/frontend.py
+++ b/cosyvoice/cli/frontend.py
@@ -42,7 +42,6 @@ class CosyVoiceFrontEnd:
                  campplus_model: str,
                  speech_tokenizer_model: str,
                  spk2info: str = '',
-                 instruct: bool = False,
                  allowed_special: str = 'all'):
         self.tokenizer = get_tokenizer()
         self.feat_extractor = feat_extractor
@@ -58,9 +57,7 @@ class CosyVoiceFrontEnd:
             self.spk2info = torch.load(spk2info, map_location=self.device)
         else:
             self.spk2info = {}
-        self.instruct = instruct
         self.allowed_special = allowed_special
-        self.inflect_parser = inflect.engine()
         self.use_ttsfrd = use_ttsfrd
         if self.use_ttsfrd:
             self.frd = ttsfrd.TtsFrontendEngine()
@@ -71,6 +68,7 @@ class CosyVoiceFrontEnd:
         else:
             self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
             self.en_tn_model = EnNormalizer()
+            self.inflect_parser = inflect.engine()

     def _extract_text_token(self, text):
         text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
@@ -111,15 +109,11 @@ class CosyVoiceFrontEnd:
         if text_frontend is False:
             return [text] if split is True else text
         text = text.strip()
-        # When generating text that contains only punctuation marks or whitespace characters
-        # - Returning empty texts ensures consistent processing logic.
-        if is_only_punctuation(text):
-            return []
-        if contains_chinese(text):
-            if self.use_ttsfrd:
-                texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
-                text = ''.join(texts)
-            else:
+        if self.use_ttsfrd:
+            texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+            text = ''.join(texts)
+        else:
+            if contains_chinese(text):
                 text = self.zh_tn_model.normalize(text)
                 text = text.replace("\n", "")
                 text = replace_blank(text)
                 text = re.sub(r'[,,、]+$', '。', text)
                 texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False))
-        else:
-            if self.use_ttsfrd:
-                texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
-                text = ''.join(texts)
             else:
                 text = self.en_tn_model.normalize(text)
                 text = spell_out_number(text, self.inflect_parser)
                 texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False))
-        if split is False:
-            return text
-        return texts
+        texts = [i for i in texts if not is_only_punctuation(i)]
+        return texts if split is True else text

     def frontend_sft(self, tts_text, spk_id):
         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
@@ -188,22 +177,9 @@ class CosyVoiceFrontEnd:
         return model_input

     def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
-        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
-        prompt_text_token, prompt_text_token_len = self._extract_text_token(instruct_text + '<|endofprompt|>')
-        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
-        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
-        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
-        if resample_rate == 24000:
-            # cosyvoice2, force speech_feat % speech_token = 2
-            token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
-            speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
-            speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
-        embedding = self._extract_spk_embedding(prompt_speech_16k)
-        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
-                       'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
-                       'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
-                       'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
-                       'llm_embedding': embedding, 'flow_embedding': embedding}
+        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
+        del model_input['llm_prompt_speech_token']
+        del model_input['llm_prompt_speech_token_len']
         return model_input

     def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index f2369f9..751ca65 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -316,6 +316,8 @@ class CosyVoice2Model:
         import tensorrt as trt
         with open(flow_decoder_estimator_model, 'rb') as f:
             self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        if self.flow.decoder.estimator_engine is None:
+            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
         self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
         self.flow.decoder.fp16 = True
diff --git a/cosyvoice/utils/class_utils.py b/cosyvoice/utils/class_utils.py
index b8cc471..e81cef0 100644
--- a/cosyvoice/utils/class_utils.py
+++ b/cosyvoice/utils/class_utils.py
@@ -32,6 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
                                              RelPositionMultiHeadedAttention)
 from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
 from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
+from cosyvoice.llm.llm import TransformerLM, Qwen2LM
+from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
+from cosyvoice.hifigan.generator import HiFTGenerator
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model


 COSYVOICE_ACTIVATION_CLASSES = {
@@ -68,3 +72,11 @@ COSYVOICE_ATTENTION_CLASSES = {
     "selfattn": MultiHeadedAttention,
     "rel_selfattn": RelPositionMultiHeadedAttention,
 }
+
+
+def get_model_type(configs):
+    if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
+        return CosyVoiceModel
+    if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
+        return CosyVoice2Model
+    raise TypeError('No valid model type found!')
diff --git a/webui.py b/webui.py
index 403b016..6c1f588 100644
--- a/webui.py
+++ b/webui.py
@@ -69,7 +69,7 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
         prompt_wav = None
     # if instruct mode, please make sure that model is iic/CosyVoice-300M-Instruct and not cross_lingual mode
     if mode_checkbox_group in ['自然语言控制']:
-        if cosyvoice.frontend.instruct is False:
+        if cosyvoice.instruct is False:
             gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir))
             yield (cosyvoice.sample_rate, default_data)
         if instruct_text == '':
@@ -79,7 +79,7 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
             gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
     # if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
     if mode_checkbox_group in ['跨语种复刻']:
-        if cosyvoice.frontend.instruct is True:
+        if cosyvoice.instruct is True:
             gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir))
             yield (cosyvoice.sample_rate, default_data)
         if instruct_text != '':
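
Notes: `get_model_type` is the helper behind the two new asserts in `CosyVoice.__init__` and `CosyVoice2.__init__`; it resolves the CLI wrapper class from the instantiated `llm`/`flow`/`hift` entries of a loaded config. Below is a minimal sketch of that check in isolation (not part of the patch; it assumes `pretrained_models/CosyVoice-300M` has already been downloaded):

``` python
import sys
sys.path.append('third_party/Matcha-TTS')
from hyperpyyaml import load_hyperpyyaml
from cosyvoice.cli.model import CosyVoiceModel
from cosyvoice.utils.class_utils import get_model_type

# Instantiate llm/flow/hift from the model config, then resolve the wrapper
# class; a CosyVoice2 config resolves to CosyVoice2Model, so the assert fails,
# which is exactly how the constructors now reject a mismatched model_dir.
with open('pretrained_models/CosyVoice-300M/cosyvoice.yaml', 'r') as f:
    configs = load_hyperpyyaml(f)
assert get_model_type(configs) == CosyVoiceModel, 'not a CosyVoice (v1) config'
```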
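Relatedly, `inference_instruct2` now lives on `CosyVoice2` only, `CosyVoice2.inference_instruct` raises `NotImplementedError`, and `frontend_instruct2` reuses `frontend_zero_shot` with an `<|endofprompt|>` suffix. A usage sketch of the post-patch instruct path (the model path and instruct text are illustrative placeholders):

``` python
import sys
sys.path.append('third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav
import torchaudio

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B')  # placeholder path
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)

# Instruct synthesis goes through inference_instruct2; internally the instruct
# text is passed to frontend_zero_shot with '<|endofprompt|>' appended and the
# llm_prompt_speech_token entries are dropped from the model input.
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用开心的语气说这句话', prompt_speech_16k, stream=False)):
    torchaudio.save('instruct2_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)

# The old entry point is rejected on CosyVoice2:
# cosyvoice.inference_instruct(...)  # raises NotImplementedError
```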