diff --git a/cosyvoice/bin/export_trt.py b/cosyvoice/bin/export_trt.py index e6d480c..c737373 100644 --- a/cosyvoice/bin/export_trt.py +++ b/cosyvoice/bin/export_trt.py @@ -1,8 +1,126 @@ -# TODO 跟export_jit一样的逻辑,完成flow部分的estimator的onnx导出。 -# tensorrt的安装方式,再这里写一下步骤提示如下,如果没有安装,那么不要执行这个脚本,提示用户先安装,不给选择 +# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +import sys + +logging.getLogger('matplotlib').setLevel(logging.WARNING) + try: import tensorrt except ImportError: - print('step1, 下载\n step2. 解压,安装whl,') -# 安装命令里tensosrt的根目录用环境变量导入,比如os.environ['tensorrt_root_dir']/bin/exetrace,然后python里subprocess里执行导出命令 -# 后面我会在run.sh里写好执行命令 tensorrt_root_dir=xxxx python cosyvoice/bin/export_trt.py --model_dir xxx \ No newline at end of file + error_msg_zh = [ + "step.1 下载 tensorrt .tar.gz 压缩包并解压,下载地址: https://developer.nvidia.com/tensorrt/download/10x", + "step.2 使用 tensorrt whl 包进行安装根据 python 版本对应进行安装,如 pip install ${TensorRT-Path}/python/tensorrt-10.2.0-cp38-none-linux_x86_64.whl", + "step.3 将 tensorrt 的 lib 路径添加进环境变量中,export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${TensorRT-Path}/lib/" + ] + print("\n".join(error_msg_zh)) + sys.exit(1) + +import torch +from cosyvoice.cli.cosyvoice import CosyVoice + +def get_args(): + parser = argparse.ArgumentParser(description='Export your model for deployment') + parser.add_argument('--model_dir', + type=str, + default='pretrained_models/CosyVoice-300M-SFT', + help='Local path to the model directory') + + parser.add_argument('--export_half', + action='store_true', + help='Export with half precision (FP16)') + + args = parser.parse_args() + print(args) + return args + +def main(): + args = get_args() + + cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False) + estimator = cosyvoice.model.flow.decoder.estimator + + dtype = torch.float32 if not args.export_half else torch.float16 + device = torch.device("cuda") + batch_size = 1 + seq_len = 256 + hidden_size = cosyvoice.model.flow.output_size + x = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device) + mask = torch.ones((batch_size, 1, seq_len), dtype=dtype, device=device) + mu = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device) + t = torch.rand((batch_size, ), dtype=dtype, device=device) + spks = torch.rand((batch_size, hidden_size), dtype=dtype, device=device) + cond = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device) + + onnx_file_name = 'estimator_fp32.onnx' if not args.export_half else 'estimator_fp16.onnx' + onnx_file_path = os.path.join(args.model_dir, onnx_file_name) + dummy_input = (x, mask, mu, t, spks, cond) + + estimator = estimator.to(dtype) + + torch.onnx.export( + estimator, + dummy_input, + onnx_file_path, + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], + output_names=['estimator_out'], + dynamic_axes={ + 'x': {2: 'seq_len'}, + 'mask': {2: 'seq_len'}, + 'mu': {2: 'seq_len'}, + 'cond': {2: 'seq_len'}, + 'estimator_out': {2: 'seq_len'}, + } + ) + + tensorrt_path = os.environ.get('tensorrt_root_dir') + if not tensorrt_path: + raise EnvironmentError("Please set the 'tensorrt_root_dir' environment variable.") + + if not os.path.isdir(tensorrt_path): + raise FileNotFoundError(f"The directory {tensorrt_path} does not exist.") + + trt_lib_path = os.path.join(tensorrt_path, "lib") + if trt_lib_path not in os.environ.get('LD_LIBRARY_PATH', ''): + print(f"Adding TensorRT lib path {trt_lib_path} to LD_LIBRARY_PATH.") + os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('LD_LIBRARY_PATH', '')}:{trt_lib_path}" + + trt_file_name = 'estimator_fp32.plan' if not args.export_half else 'estimator_fp16.plan' + trt_file_path = os.path.join(args.model_dir, trt_file_name) + + trtexec_bin = os.path.join(tensorrt_path, 'bin/trtexec') + trtexec_cmd = f"{trtexec_bin} --onnx={onnx_file_path} --saveEngine={trt_file_path} " \ + "--minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 " \ + "--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose " + \ + ("--fp16" if args.export_half else "") + + print("execute ", trtexec_cmd) + + os.system(trtexec_cmd) + + # print("x.shape", x.shape) + # print("mask.shape", mask.shape) + # print("mu.shape", mu.shape) + # print("t.shape", t.shape) + # print("spks.shape", spks.shape) + # print("cond.shape", cond.shape) + +if __name__ == "__main__": + main() diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 49fe15f..87f5f3b 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging class CosyVoice: - def __init__(self, model_dir, load_jit=True): + def __init__(self, model_dir, load_jit=True, load_trt=True, use_fp16=False): instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir if not os.path.exists(model_dir): @@ -39,9 +39,13 @@ class CosyVoice: self.model.load('{}/llm.pt'.format(model_dir), '{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) + if load_jit: self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir), '{}/llm.llm.fp16.zip'.format(model_dir)) + if load_trt: + self.model.load_trt(model_dir, use_fp16) + del configs def list_avaliable_spks(self): diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 99ccbe5..747b0ce 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import torch import numpy as np import threading @@ -19,7 +20,6 @@ from contextlib import nullcontext import uuid from cosyvoice.utils.common import fade_in_out - class CosyVoiceModel: def __init__(self, @@ -66,6 +66,22 @@ class CosyVoiceModel: llm_llm = torch.jit.load(llm_llm_model) self.llm.llm = llm_llm + def load_trt(self, model_dir, use_fp16): + import tensorrt as trt + trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan' + trt_file_path = os.path.join(model_dir, trt_file_name) + if not os.path.isfile(trt_file_path): + raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file" + + trt.init_libnvinfer_plugins(None, "") + logger = trt.Logger(trt.Logger.WARNING) + runtime = trt.Runtime(logger) + with open(trt_file_path, 'rb') as f: + serialized_engine = f.read() + engine = runtime.deserialize_cuda_engine(serialized_engine) + self.flow.decoder.estimator_context = engine.create_execution_context() + self.flow.decoder.estimator = None + def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): with self.llm_context: for i in self.llm.inference(text=text.to(self.device), diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py index 4349279..be063d3 100755 --- a/cosyvoice/flow/decoder.py +++ b/cosyvoice/flow/decoder.py @@ -159,7 +159,7 @@ class ConditionalDecoder(nn.Module): _type_: _description_ """ - t = self.time_embeddings(t) + t = self.time_embeddings(t).to(t.dtype) t = self.time_mlp(t) x = pack([x, mu], "b * t")[0] diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index 8cbf013..10a0bf3 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -113,7 +113,7 @@ class MaskedDiffWithXvec(torch.nn.Module): # concat text and prompt_text token_len1, token_len2 = prompt_token.shape[1], token.shape[1] token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len - mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding) + mask = (~make_pad_mask(token_len)).to(embedding.dtype).unsqueeze(-1).to(embedding) token = self.input_embedding(torch.clamp(token, min=0)) * mask # text encode diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index f82eaae..cccbf98 100755 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -50,7 +50,7 @@ class ConditionalCFM(BASECFM): shape: (batch_size, n_feats, mel_timesteps) """ z = torch.randn_like(mu) * temperature - t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) if self.t_scheduler == 'cosine': t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) @@ -71,6 +71,7 @@ class ConditionalCFM(BASECFM): cond: Not used but kept for future purposes """ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + t = t.unsqueeze(dim=0) # I am storing this because I can later plot it by putting a debugger here and saving it to a file # Or in future might add like a return_all_steps flag @@ -96,6 +97,33 @@ class ConditionalCFM(BASECFM): return sol[-1] + def forward_estimator(self, x, mask, mu, t, spks, cond): + + if self.estimator is not None: + return self.estimator.forward(x, mask, mu, t, spks, cond) + else: + assert self.training is False, 'tensorrt cannot be used in training' + bs = x.shape[0] + hs = x.shape[1] + seq_len = x.shape[2] + # assert bs == 1 and hs == 80 + ret = torch.empty_like(x) + self.estimator_context.set_input_shape("x", x.shape) + self.estimator_context.set_input_shape("mask", mask.shape) + self.estimator_context.set_input_shape("mu", mu.shape) + self.estimator_context.set_input_shape("t", t.shape) + self.estimator_context.set_input_shape("spks", spks.shape) + self.estimator_context.set_input_shape("cond", cond.shape) + bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()] + names = ['x', 'mask', 'mu', 't', 'spks', 'cond', 'estimator_out'] + + for i in range(len(bindings)): + self.estimator_context.set_tensor_address(names[i], bindings[i]) + + handle = torch.cuda.current_stream().cuda_stream + self.estimator_context.execute_async_v3(stream_handle=handle) + return ret + def compute_loss(self, x1, mask, mu, spks=None, cond=None): """Computes diffusion loss