add some instruction and assert

lyuxiang.lx
2024-12-30 16:41:57 +08:00
parent bfcbc73df8
commit b9ddcba5fd
6 changed files with 52 additions and 63 deletions

cosyvoice/cli/cosyvoice.py

@@ -20,23 +20,24 @@ import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.class_utils import get_model_type


 class CosyVoice:

     def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
-        instruct = True if '-Instruct' in model_dir else False
+        self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
             configs = load_hyperpyyaml(f)
+        assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
                                           '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                           '{}/spk2info.pt'.format(model_dir),
-                                          instruct,
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
         if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
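
The practical effect of the new `get_model_type` assert, as a hedged sketch (the model directory name below is an assumed example, not part of this diff): handing a CosyVoice2 checkpoint to the `CosyVoice` wrapper now fails right after the yaml is loaded, with a message naming the offending directory, instead of failing later inside inference.

```python
# Hedged sketch; 'pretrained_models/CosyVoice2-0.5B' is an assumed example path.
from cosyvoice.cli.cosyvoice import CosyVoice

try:
    cosyvoice = CosyVoice('pretrained_models/CosyVoice2-0.5B')  # a CosyVoice2 config
except AssertionError as e:
    print(e)  # do not use pretrained_models/CosyVoice2-0.5B for CosyVoice initialization!
```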
@@ -85,8 +86,6 @@ class CosyVoice:
             start_time = time.time()

     def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
-        if self.frontend.instruct is True and isinstance(self.model, CosyVoiceModel):
-            raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
@@ -98,8 +97,8 @@ class CosyVoice:
             start_time = time.time()

     def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoiceModel)
-        if self.frontend.instruct is False:
+        assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
+        if self.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
         instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
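
Note that `inference_instruct` is a generator, so the renamed flag check and the new assert only fire once iteration starts. A hedged sketch (the model directory and texts are assumed examples):

```python
# Hedged sketch; the model directory and spk_id are assumed examples.
from cosyvoice.cli.cosyvoice import CosyVoice

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')  # no '-Instruct' in the name
gen = cosyvoice.inference_instruct('Hello.', '中文女', 'Speak cheerfully.')  # no error yet
try:
    next(gen)  # the generator body runs here, so the instruct check fires here
except ValueError as e:
    print(e)  # pretrained_models/CosyVoice-300M do not support instruct inference
```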
@@ -112,18 +111,6 @@ class CosyVoice:
                 yield model_output
                 start_time = time.time()

-    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoice2Model)
-        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
-            start_time = time.time()
-            logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
-                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
-                yield model_output
-                start_time = time.time()
-
     def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
         model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
         start_time = time.time()
@@ -137,18 +124,18 @@ class CosyVoice:
 class CosyVoice2(CosyVoice):

     def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
-        instruct = True if '-Instruct' in model_dir else False
+        self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
                                           '{}/speech_tokenizer_v2.onnx'.format(model_dir),
                                           '{}/spk2info.pt'.format(model_dir),
-                                          instruct,
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
         if torch.cuda.is_available() is False and load_jit is True:
@@ -168,3 +155,18 @@ class CosyVoice2(CosyVoice):
         if load_trt:
             self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
         del configs
+
+    def inference_instruct(self, *args, **kwargs):
+        raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
+
+    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
+        assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
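
Taken together with the removal above, the instruct entry points are now split cleanly between the two wrappers: `CosyVoice` keeps `inference_instruct`, while `CosyVoice2` keeps `inference_instruct2` and explicitly rejects the other. A hedged usage sketch (paths and texts are assumed examples):

```python
# Hedged sketch; paths and texts are assumed examples.
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav

cosyvoice2 = CosyVoice2('pretrained_models/CosyVoice2-0.5B')
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)

# CosyVoice2 exposes only the prompt-audio-based instruct interface.
for out in cosyvoice2.inference_instruct2('A birthday gift from afar.', 'Speak with a cheerful tone.', prompt_speech_16k):
    wav = out['tts_speech']

# The spk_id-based variant now raises NotImplementedError on CosyVoice2:
# cosyvoice2.inference_instruct('text', 'spk_id', 'instruct text')
```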

cosyvoice/cli/frontend.py

@@ -42,7 +42,6 @@ class CosyVoiceFrontEnd:
                  campplus_model: str,
                  speech_tokenizer_model: str,
                  spk2info: str = '',
-                 instruct: bool = False,
                  allowed_special: str = 'all'):
         self.tokenizer = get_tokenizer()
         self.feat_extractor = feat_extractor
@@ -58,9 +57,7 @@
             self.spk2info = torch.load(spk2info, map_location=self.device)
         else:
             self.spk2info = {}
-        self.instruct = instruct
         self.allowed_special = allowed_special
-        self.inflect_parser = inflect.engine()
         self.use_ttsfrd = use_ttsfrd
         if self.use_ttsfrd:
             self.frd = ttsfrd.TtsFrontendEngine()
@@ -71,6 +68,7 @@
         else:
             self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
             self.en_tn_model = EnNormalizer()
+            self.inflect_parser = inflect.engine()

     def _extract_text_token(self, text):
         text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
@@ -111,15 +109,11 @@
         if text_frontend is False:
             return [text] if split is True else text
         text = text.strip()
-        # When generating text that contains only punctuation marks or whitespace characters
-        # - Returning empty texts ensures consistent processing logic.
-        if is_only_punctuation(text):
-            return []
-        if contains_chinese(text):
-            if self.use_ttsfrd:
-                texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
-                text = ''.join(texts)
-            else:
+        if self.use_ttsfrd:
+            texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+            text = ''.join(texts)
+        else:
+            if contains_chinese(text):
                 text = self.zh_tn_model.normalize(text)
                 text = text.replace("\n", "")
                 text = replace_blank(text)
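
The restructure hoists the `use_ttsfrd` check above the language check, so the ttsfrd path handles both languages in one place while the ZhNormalizer/EnNormalizer path keeps the per-language split. A self-contained toy mirroring the new branch shape (all helpers below are stand-ins, not the repo's):

```python
# Toy sketch of the hoisted branch structure; helpers are stand-ins.
def contains_chinese(text):
    return any('\u4e00' <= ch <= '\u9fff' for ch in text)

def normalize(text, use_ttsfrd=False):
    if use_ttsfrd:
        return text.strip()            # one language-independent path
    if contains_chinese(text):
        return text.strip() + ' [zh]'  # zh-specific normalization would go here
    return text.strip() + ' [en]'      # en-specific normalization would go here

print(normalize('hello'))  # hello [en]
print(normalize('你好'))   # 你好 [zh]
```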
@@ -130,18 +124,13 @@
                 text = re.sub(r'[，,、]+$', '。', text)
                 texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                              token_min_n=60, merge_len=20, comma_split=False))
             else:
-            if self.use_ttsfrd:
-                texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
-                text = ''.join(texts)
-            else:
                 text = self.en_tn_model.normalize(text)
                 text = spell_out_number(text, self.inflect_parser)
                 texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                              token_min_n=60, merge_len=20, comma_split=False))
-        if split is False:
-            return text
-        return texts
+        texts = [i for i in texts if not is_only_punctuation(i)]
+        return texts if split is True else text

     def frontend_sft(self, tts_text, spk_id):
         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
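
With the early `return []` gone, punctuation-only segments are now filtered after splitting instead of rejecting the whole input up front, so mixed inputs keep their valid pieces. A self-contained sketch of that filter step (`is_only_punctuation` below is a stand-in re-implementation, not the repo's helper):

```python
import re

def is_only_punctuation(text):
    # Stand-in: a segment counts as punctuation-only if it has no word characters.
    return re.fullmatch(r'[\W_]*', text) is not None

texts = ['Hello there.', '!!!', '...', 'Second sentence.']
texts = [i for i in texts if not is_only_punctuation(i)]
print(texts)  # ['Hello there.', 'Second sentence.']
```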
@@ -188,22 +177,9 @@
         return model_input

     def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
-        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
-        prompt_text_token, prompt_text_token_len = self._extract_text_token(instruct_text + '<|endofprompt|>')
-        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
-        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
-        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
-        if resample_rate == 24000:
-            # cosyvoice2, force speech_feat % speech_token = 2
-            token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
-            speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
-            speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
-        embedding = self._extract_spk_embedding(prompt_speech_16k)
-        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
-                       'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
-                       'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
-                       'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
-                       'llm_embedding': embedding, 'flow_embedding': embedding}
+        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
+        del model_input['llm_prompt_speech_token']
+        del model_input['llm_prompt_speech_token_len']
         return model_input

     def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
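
The rewritten `frontend_instruct2` reuses `frontend_zero_shot`, passing the instruct text (plus `<|endofprompt|>`) as the prompt transcript and then deleting the LLM-side prompt speech tokens, so the LLM is conditioned only on the instruct text while the flow model still sees the prompt audio. A hedged sketch of the resulting keys, assuming an already-constructed `frontend` instance and a 16 kHz prompt tensor (key names come from `frontend_zero_shot` in this same file):

```python
# Assumes `frontend` is a constructed CosyVoiceFrontEnd instance.
model_input = frontend.frontend_instruct2(tts_text, instruct_text, prompt_speech_16k, 24000)

assert 'llm_prompt_speech_token' not in model_input  # LLM sees only the instruct text
assert 'flow_prompt_speech_token' in model_input     # flow is still prompted by the audio
assert 'prompt_speech_feat' in model_input
```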

cosyvoice/cli/model.py

@@ -316,6 +316,8 @@ class CosyVoice2Model:
         import tensorrt as trt
         with open(flow_decoder_estimator_model, 'rb') as f:
             self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        if self.flow.decoder.estimator_engine is None:
+            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
         self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
         self.flow.decoder.fp16 = True
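
Context for the new guard: TensorRT's `deserialize_cuda_engine` returns `None` rather than raising when the plan file is corrupt or was built for a different GPU or TensorRT version, so without the check the very next attribute access would die with an opaque `NoneType` error. A hedged standalone sketch (the function name and path handling are illustrative, not from the repo):

```python
import tensorrt as trt

def load_plan(plan_path):
    runtime = trt.Runtime(trt.Logger(trt.Logger.INFO))
    with open(plan_path, 'rb') as f:
        engine = runtime.deserialize_cuda_engine(f.read())
    if engine is None:
        # Name the offending file, mirroring the check added above.
        raise ValueError('failed to load trt {}'.format(plan_path))
    return engine.create_execution_context()
```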

cosyvoice/utils/class_utils.py

@@ -32,6 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
                                              RelPositionMultiHeadedAttention)
 from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
 from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
+from cosyvoice.llm.llm import TransformerLM, Qwen2LM
+from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
+from cosyvoice.hifigan.generator import HiFTGenerator
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model


 COSYVOICE_ACTIVATION_CLASSES = {
@@ -68,3 +72,11 @@ COSYVOICE_ATTENTION_CLASSES = {
     "selfattn": MultiHeadedAttention,
     "rel_selfattn": RelPositionMultiHeadedAttention,
 }
+
+
+def get_model_type(configs):
+    if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
+        return CosyVoiceModel
+    if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
+        return CosyVoice2Model
+    raise TypeError('No valid model type found!')
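
A hedged usage sketch of the new helper (the yaml path is an assumed example): `get_model_type` dispatches on the classes that hyperpyyaml actually instantiated, which is exactly what the new asserts in cosyvoice/cli/cosyvoice.py rely on.

```python
# Hedged sketch; the model directory is an assumed example.
from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.class_utils import get_model_type
from cosyvoice.cli.model import CosyVoiceModel

with open('pretrained_models/CosyVoice-300M/cosyvoice.yaml', 'r') as f:
    configs = load_hyperpyyaml(f)
assert get_model_type(configs) == CosyVoiceModel
```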