add some instruction and assert

lyuxiang.lx
2024-12-30 16:41:57 +08:00
parent bfcbc73df8
commit b9ddcba5fd
6 changed files with 52 additions and 63 deletions

README.md

@@ -121,13 +121,10 @@ We strongly recommend using `CosyVoice2-0.5B` for better performance.
 For zero_shot/cross_lingual inference, please use `CosyVoice-300M` model.
 For sft inference, please use `CosyVoice-300M-SFT` model.
 For instruct inference, please use `CosyVoice-300M-Instruct` model.
-First, add `third_party/Matcha-TTS` to your `PYTHONPATH`.
-``` sh
-export PYTHONPATH=third_party/Matcha-TTS
-```
 ``` python
+import sys
+sys.path.append('third_party/Matcha-TTS')
 from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 from cosyvoice.utils.file_utils import load_wav
 import torchaudio
@@ -161,7 +158,7 @@ print(cosyvoice.list_available_spks())
 for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
     torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-25Hz') # or change to pretrained_models/CosyVoice-300M for 50Hz inference
+cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M') # or change to pretrained_models/CosyVoice-300M-25Hz for 25Hz inference
 # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
 prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
 for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
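Note: instruct-mode usage follows the same pattern as the sft/zero_shot snippets above. A minimal sketch, assuming the `CosyVoice-300M-Instruct` model has been downloaded; the texts below are placeholders, not part of this commit:

``` python
# hedged sketch: instruct inference with the instruct-tuned 300M model;
# inference_instruct takes (tts_text, spk_id, instruct_text, ...)
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的勇气与智慧。', '中文男', 'Theo is a fiery, passionate rebel leader.', stream=False)):
    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
```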

cosyvoice/cli/cosyvoice.py

@@ -20,23 +20,24 @@ import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.class_utils import get_model_type


 class CosyVoice:

     def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
-        instruct = True if '-Instruct' in model_dir else False
+        self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
             configs = load_hyperpyyaml(f)
+        assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
                                           '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                           '{}/spk2info.pt'.format(model_dir),
-                                          instruct,
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
         if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
@@ -85,8 +86,6 @@ class CosyVoice:
             start_time = time.time()

     def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
-        if self.frontend.instruct is True and isinstance(self.model, CosyVoiceModel):
-            raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
@@ -98,8 +97,8 @@ class CosyVoice:
             start_time = time.time()

     def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoiceModel)
-        if self.frontend.instruct is False:
+        assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
+        if self.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
         instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
@@ -112,18 +111,6 @@ class CosyVoice:
                 yield model_output
                 start_time = time.time()

-    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoice2Model)
-        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
-            start_time = time.time()
-            logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
-                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
-                yield model_output
-                start_time = time.time()
-
     def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
         model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
         start_time = time.time()
@@ -137,18 +124,18 @@ class CosyVoice:
 class CosyVoice2(CosyVoice):

     def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
-        instruct = True if '-Instruct' in model_dir else False
+        self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
                                           '{}/speech_tokenizer_v2.onnx'.format(model_dir),
                                           '{}/spk2info.pt'.format(model_dir),
-                                          instruct,
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
         if torch.cuda.is_available() is False and load_jit is True:
@@ -168,3 +155,18 @@ class CosyVoice2(CosyVoice):
         if load_trt:
             self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
         del configs

+    def inference_instruct(self, *args, **kwargs):
+        raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
+
+    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
+        assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
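Note: after this file's changes each wrapper owns exactly one instruct entry point: `CosyVoice.inference_instruct` (guarded by the new assert) and `CosyVoice2.inference_instruct2`, while `CosyVoice2.inference_instruct` fails fast. A minimal sketch of the intended dispatch, assuming a downloaded `CosyVoice2-0.5B` and a local prompt wav; texts and paths are illustrative:

``` python
# hedged sketch: the instruct2 path now lives on CosyVoice2
cosyvoice2 = CosyVoice2('pretrained_models/CosyVoice2-0.5B')
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
for i, j in enumerate(cosyvoice2.inference_instruct2('收到好友从远方寄来的生日礼物。', '用开心的语气说这句话', prompt_speech_16k, stream=False)):
    torchaudio.save('instruct2_{}.wav'.format(i), j['tts_speech'], cosyvoice2.sample_rate)
# cosyvoice2.inference_instruct(...) now raises NotImplementedError
```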

cosyvoice/cli/frontend.py

@@ -42,7 +42,6 @@ class CosyVoiceFrontEnd:
                  campplus_model: str,
                  speech_tokenizer_model: str,
                  spk2info: str = '',
-                 instruct: bool = False,
                  allowed_special: str = 'all'):
         self.tokenizer = get_tokenizer()
         self.feat_extractor = feat_extractor
@@ -58,9 +57,7 @@ class CosyVoiceFrontEnd:
             self.spk2info = torch.load(spk2info, map_location=self.device)
         else:
             self.spk2info = {}
-        self.instruct = instruct
         self.allowed_special = allowed_special
-        self.inflect_parser = inflect.engine()
         self.use_ttsfrd = use_ttsfrd
         if self.use_ttsfrd:
             self.frd = ttsfrd.TtsFrontendEngine()
@@ -71,6 +68,7 @@ class CosyVoiceFrontEnd:
         else:
             self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
             self.en_tn_model = EnNormalizer()
+            self.inflect_parser = inflect.engine()

     def _extract_text_token(self, text):
         text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
@@ -111,15 +109,11 @@ class CosyVoiceFrontEnd:
         if text_frontend is False:
             return [text] if split is True else text
         text = text.strip()
-        # When generating text that contains only punctuation marks or whitespace characters
-        # - Returning empty texts ensures consistent processing logic.
-        if is_only_punctuation(text):
-            return []
-        if contains_chinese(text):
-            if self.use_ttsfrd:
-                texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
-                text = ''.join(texts)
-            else:
+        if self.use_ttsfrd:
+            texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+            text = ''.join(texts)
+        else:
+            if contains_chinese(text):
                 text = self.zh_tn_model.normalize(text)
                 text = text.replace("\n", "")
                 text = replace_blank(text)
@@ -130,18 +124,13 @@ class CosyVoiceFrontEnd:
                 text = re.sub(r'[,、]+$', '', text)
                 texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                              token_min_n=60, merge_len=20, comma_split=False))
-        else:
-            if self.use_ttsfrd:
-                texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
-                text = ''.join(texts)
             else:
                 text = self.en_tn_model.normalize(text)
                 text = spell_out_number(text, self.inflect_parser)
                 texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                              token_min_n=60, merge_len=20, comma_split=False))
-        if split is False:
-            return text
-        return texts
+        texts = [i for i in texts if not is_only_punctuation(i)]
+        return texts if split is True else text

     def frontend_sft(self, tts_text, spk_id):
         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
@@ -188,22 +177,9 @@ class CosyVoiceFrontEnd:
         return model_input

     def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
-        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
-        prompt_text_token, prompt_text_token_len = self._extract_text_token(instruct_text + '<|endofprompt|>')
-        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
-        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
-        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
-        if resample_rate == 24000:
-            # cosyvoice2, force speech_feat % speech_token = 2
-            token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
-            speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
-            speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
-        embedding = self._extract_spk_embedding(prompt_speech_16k)
-        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
-                       'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
-                       'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
-                       'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
-                       'llm_embedding': embedding, 'flow_embedding': embedding}
+        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
+        del model_input['llm_prompt_speech_token']
+        del model_input['llm_prompt_speech_token_len']
         return model_input

     def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
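Note: the `frontend_instruct2` rewrite is a behaviour-preserving deduplication: the instruct text plus the `<|endofprompt|>` marker is fed through `frontend_zero_shot` as prompt text, and the two LLM prompt-token entries are then dropped so the LLM is conditioned on the instruct text alone while flow matching keeps the prompt speech tokens and features. The surviving keys match the dict the removed code built by hand:

``` python
# keys returned by frontend_instruct2 after this commit (per frontend_zero_shot,
# minus llm_prompt_speech_token / llm_prompt_speech_token_len):
# {'text', 'text_len', 'prompt_text', 'prompt_text_len',
#  'flow_prompt_speech_token', 'flow_prompt_speech_token_len',
#  'prompt_speech_feat', 'prompt_speech_feat_len',
#  'llm_embedding', 'flow_embedding'}
```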

cosyvoice/cli/model.py

@@ -316,6 +316,8 @@ class CosyVoice2Model:
             import tensorrt as trt
         with open(flow_decoder_estimator_model, 'rb') as f:
             self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        if self.flow.decoder.estimator_engine is None:
+            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
         self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
         self.flow.decoder.fp16 = True
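Note: `deserialize_cuda_engine` returns `None` rather than raising when the plan file is corrupt or was built for a different GPU or TensorRT version, so without this guard the failure would only surface later as an `AttributeError` on `create_execution_context()`. A minimal sketch of the failure mode the new check converts into an actionable error; `plan_bytes` and `plan_path` are illustrative names:

``` python
import tensorrt as trt  # hedged sketch, not part of the diff

runtime = trt.Runtime(trt.Logger(trt.Logger.INFO))
engine = runtime.deserialize_cuda_engine(plan_bytes)  # None on a bad/incompatible plan
if engine is None:
    raise ValueError('failed to load trt {}'.format(plan_path))  # fail fast with the path
context = engine.create_execution_context()
```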

cosyvoice/utils/class_utils.py

@@ -32,6 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
                                             RelPositionMultiHeadedAttention)
 from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
 from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
+from cosyvoice.llm.llm import TransformerLM, Qwen2LM
+from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
+from cosyvoice.hifigan.generator import HiFTGenerator
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model


 COSYVOICE_ACTIVATION_CLASSES = {
@@ -68,3 +72,11 @@ COSYVOICE_ATTENTION_CLASSES = {
"selfattn": MultiHeadedAttention, "selfattn": MultiHeadedAttention,
"rel_selfattn": RelPositionMultiHeadedAttention, "rel_selfattn": RelPositionMultiHeadedAttention,
} }
def get_model_type(configs):
if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
return CosyVoiceModel
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
return CosyVoice2Model
raise TypeError('No valid model type found!')
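Note: `get_model_type` dispatches on the classes that hyperpyyaml has already instantiated, not on the directory name, so a CosyVoice yaml handed to `CosyVoice2` (or vice versa) now fails fast at init. A minimal usage sketch mirroring the new asserts in the CLI wrappers; the path is illustrative:

``` python
from hyperpyyaml import load_hyperpyyaml

# hedged sketch: classify a model dir by its instantiated config
with open('pretrained_models/CosyVoice-300M/cosyvoice.yaml', 'r') as f:
    configs = load_hyperpyyaml(f)
model_cls = get_model_type(configs)  # CosyVoiceModel here; TypeError if nothing matches
```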

webui.py

@@ -69,7 +69,7 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
         prompt_wav = None
     # if instruct mode, please make sure that model is iic/CosyVoice-300M-Instruct and not cross_lingual mode
     if mode_checkbox_group in ['自然语言控制']:
-        if cosyvoice.frontend.instruct is False:
+        if cosyvoice.instruct is False:
             gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir))
             yield (cosyvoice.sample_rate, default_data)
         if instruct_text == '':
@@ -79,7 +79,7 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
             gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
     # if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
     if mode_checkbox_group in ['跨语种复刻']:
-        if cosyvoice.frontend.instruct is True:
+        if cosyvoice.instruct is True:
             gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir))
             yield (cosyvoice.sample_rate, default_data)
         if instruct_text != '':
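Note: the webui now reads the flag off the CLI wrapper itself; `self.instruct` is set in `__init__` from the model dir name. A quick check, as a hedged sketch:

``` python
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
assert cosyvoice.instruct is True  # derived from '-Instruct' in model_dir
```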