use automodel

Commit to https://github.com/FunAudioLLM/CosyVoice.git
@@ -23,8 +23,10 @@ import torch
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.cli.cosyvoice import AutoModel
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
 from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.class_utils import get_model_type
 
 
 def get_args():
@@ -57,15 +59,17 @@ def main():
     torch._C._jit_set_profiling_mode(False)
     torch._C._jit_set_profiling_executor(False)
 
-    try:
-        model = CosyVoice(args.model_dir)
-    except Exception:
-        try:
-            model = CosyVoice2(args.model_dir)
-        except Exception:
-            raise TypeError('no valid model_type!')
+    model = AutoModel(model_dir=args.model_dir)
 
-    if not isinstance(model, CosyVoice2):
+    if get_model_type(model.model) == CosyVoiceModel:
+        # 1. export flow encoder
+        flow_encoder = model.model.flow.encoder
+        script = get_optimized_script(flow_encoder)
+        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(flow_encoder.half())
+        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export flow_encoder')
+    elif get_model_type(model.model) == CosyVoice2Model:
         # 1. export llm text_encoder
         llm_text_encoder = model.model.llm.text_encoder
         script = get_optimized_script(llm_text_encoder)
@@ -90,13 +94,7 @@ def main():
         script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
         logging.info('successfully export flow_encoder')
     else:
-        # 3. export flow encoder
-        flow_encoder = model.model.flow.encoder
-        script = get_optimized_script(flow_encoder)
-        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
-        script = get_optimized_script(flow_encoder.half())
-        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
-        logging.info('successfully export flow_encoder')
+        raise ValueError('unsupported model type')
 
 
 if __name__ == '__main__':
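With these changes the export script no longer probes CosyVoice and CosyVoice2 one after another; it loads whatever checkpoint sits in model_dir through AutoModel and branches on the concrete model class. A minimal sketch of that pattern, assuming a locally downloaded checkpoint (the path below is only an example):

    from cosyvoice.cli.cosyvoice import AutoModel
    from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
    from cosyvoice.utils.class_utils import get_model_type

    model = AutoModel(model_dir='pretrained_models/CosyVoice2-0.5B')  # example path
    model_type = get_model_type(model.model)
    if model_type == CosyVoiceModel:
        print('CosyVoice checkpoint')
    elif model_type == CosyVoice2Model:
        print('CosyVoice2 checkpoint')
    elif model_type == CosyVoice3Model:
        print('CosyVoice3 checkpoint')
    else:
        raise ValueError('unsupported model type')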
@@ -27,7 +27,7 @@ from tqdm import tqdm
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2, CosyVoice3
+from cosyvoice.cli.cosyvoice import AutoModel
 from cosyvoice.utils.file_utils import logging
 
 
@@ -58,16 +58,7 @@ def main():
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')
 
-    try:
-        model = CosyVoice(args.model_dir)
-    except Exception:
-        try:
-            model = CosyVoice2(args.model_dir)
-        except Exception:
-            try:
-                model = CosyVoice3(args.model_dir)
-            except Exception:
-                raise TypeError('no valid model_type!')
+    model = AutoModel(model_dir=args.model_dir)
 
     # 1. export flow decoder estimator
     estimator = model.model.flow.decoder.estimator
@@ -196,7 +196,7 @@ class CosyVoice2(CosyVoice):
 
 class CosyVoice3(CosyVoice2):
 
-    def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
+    def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
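With load_jit gone from the signature, a direct construction of CosyVoice3 under the new interface would look roughly like this (the checkpoint directory name is hypothetical):

    from cosyvoice.cli.cosyvoice import CosyVoice3

    cosyvoice = CosyVoice3('pretrained_models/CosyVoice3-0.5B',  # hypothetical directory name
                           load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1)
    print(cosyvoice.sample_rate)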
@@ -215,9 +215,9 @@ class CosyVoice3(CosyVoice2):
                                           '{}/spk2info.pt'.format(model_dir),
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
-        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
-            load_jit, load_trt, fp16 = False, False, False
-            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
+        if torch.cuda.is_available() is False and (load_trt is True or fp16 is True):
+            load_trt, fp16 = False, False
+            logging.warning('no cuda device, set load_trt/fp16 to False')
         self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
@@ -225,8 +225,23 @@ class CosyVoice3(CosyVoice2):
         if load_vllm:
             self.model.load_vllm('{}/vllm'.format(model_dir))
         if load_trt:
             if self.fp16 is True:
                 logging.warning('DiT tensorRT fp16 engine have some performance issue, use at caution!')
             self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                 trt_concurrent,
                                 self.fp16)
         del configs
 
 
+def AutoModel(**kwargs):
+    if not os.path.exists(kwargs['model_dir']):
+        kwargs['model_dir'] = snapshot_download(kwargs['model_dir'])
+    if os.path.exists('{}/cosyvoice.yaml'.format(kwargs['model_dir'])):
+        return CosyVoice(**kwargs)
+    elif os.path.exists('{}/cosyvoice2.yaml'.format(kwargs['model_dir'])):
+        return CosyVoice2(**kwargs)
+    elif os.path.exists('{}/cosyvoice3.yaml'.format(kwargs['model_dir'])):
+        return CosyVoice3(**kwargs)
+    else:
+        raise TypeError('No valid model type found!')
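The new AutoModel helper is what the export scripts above call: it downloads the directory via snapshot_download when the path does not exist locally, then dispatches on which YAML file the directory contains. A hypothetical call (the model path and extra kwargs are examples; kwargs are forwarded unchanged to the selected wrapper class, so only pass arguments that class accepts):

    from cosyvoice.cli.cosyvoice import AutoModel

    cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice2-0.5B', fp16=False)  # example path
    print(type(cosyvoice).__name__)  # CosyVoice, CosyVoice2 or CosyVoice3, depending on the yaml found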
@@ -122,6 +122,9 @@ class CosyVoiceFrontEnd:
         return speech_feat, speech_feat_len
 
     def text_normalize(self, text, split=True, text_frontend=True):
+        # NOTE skip text_frontend when ssml symbol in text
+        if '<|' in text and '|>' in text:
+            text_frontend = False
         if isinstance(text, Generator):
             logging.info('get tts_text generator, will skip text_normalize!')
             return [text]
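The effect of the added guard is that any input carrying <|...|> markers bypasses the text frontend entirely. A standalone illustration of the check (the tag name is made up):

    def wants_text_frontend(text, text_frontend=True):
        # mirrors the new check in CosyVoiceFrontEnd.text_normalize
        if '<|' in text and '|>' in text:
            text_frontend = False
        return text_frontend

    print(wants_text_frontend('plain text to synthesize'))       # True, still normalized
    print(wants_text_frontend('<|breath|> text with ssml tag'))  # False, passed through untouched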
@@ -92,29 +92,14 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
 def export_cosyvoice2_vllm(model, model_path, device):
     if os.path.exists(model_path):
         return
-    pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
-    vocab_size = model.speech_embedding.num_embeddings
-    feature_size = model.speech_embedding.embedding_dim
-    pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
 
     dtype = torch.bfloat16
     # lm_head
     use_bias = True if model.llm_decoder.bias is not None else False
-    new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=use_bias)
-    with torch.no_grad():
-        new_lm_head.weight[:vocab_size] = model.llm_decoder.weight
-        new_lm_head.weight[vocab_size:] = 0
-        if use_bias is True:
-            new_lm_head.bias[:vocab_size] = model.llm_decoder.bias
-            new_lm_head.bias[vocab_size:] = 0
-    model.llm.model.lm_head = new_lm_head
-    new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
+    model.llm.model.lm_head = model.llm_decoder
     # embed_tokens
     embed_tokens = model.llm.model.model.embed_tokens
-    with torch.no_grad():
-        new_codec_embed.weight[:vocab_size] = model.speech_embedding.weight
-        new_codec_embed.weight[vocab_size:] = 0
-    model.llm.model.set_input_embeddings(new_codec_embed)
+    model.llm.model.set_input_embeddings(model.speech_embedding)
     model.llm.model.to(device)
     model.llm.model.to(dtype)
     tmp_vocab_size = model.llm.model.config.vocab_size
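For reference, the removed branch padded the speech vocabulary up to a multiple of 64 and copied the weights into a freshly built lm_head and embedding table, while the replacement reuses model.llm_decoder and model.speech_embedding directly. The padding arithmetic that disappears is plain round-up-to-multiple, e.g.:

    pad_to = 64            # DEFAULT_VOCAB_PADDING_SIZE in the removed code
    vocab_size = 6564      # example size, not taken from a real checkpoint
    pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
    print(pad_vocab_size)  # 6592, the next multiple of 64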
@@ -122,14 +107,12 @@ def export_cosyvoice2_vllm(model, model_path, device):
     del model.llm.model.generation_config.eos_token_id
     del model.llm.model.config.bos_token_id
     del model.llm.model.config.eos_token_id
-    model.llm.model.config.vocab_size = pad_vocab_size
+    model.llm.model.config.vocab_size = model.speech_embedding.num_embeddings
     model.llm.model.config.tie_word_embeddings = False
     model.llm.model.config.use_bias = use_bias
     model.llm.model.save_pretrained(model_path)
     if use_bias is True:
         os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
     else:
         os.system('sed -i s@Qwen2ForCausalLM@Qwen2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
     model.llm.model.config.vocab_size = tmp_vocab_size
     model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
     model.llm.model.set_input_embeddings(embed_tokens)
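A hypothetical end-to-end use of the exporter under these changes; the import path of export_cosyvoice2_vllm and the output directory are assumptions, while the argument order matches the signature shown above (the Qwen-based LLM wrapper, a target path, and a device):

    import torch
    from cosyvoice.cli.cosyvoice import AutoModel
    from cosyvoice.utils.common import export_cosyvoice2_vllm  # assumed module location

    model_dir = 'pretrained_models/CosyVoice2-0.5B'             # example checkpoint
    cosyvoice = AutoModel(model_dir=model_dir)
    export_cosyvoice2_vllm(cosyvoice.model.llm, '{}/vllm'.format(model_dir), torch.device('cuda'))
    # AutoModel(model_dir=model_dir, load_vllm=True) would then pick up the exported weights from {model_dir}/vllm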