From 5f21aef7860a3a7af140bc93f294e155ac7b5bcb Mon Sep 17 00:00:00 2001 From: "zhoubofan.zbf" Date: Thu, 29 Aug 2024 23:35:07 +0800 Subject: [PATCH 01/11] add flow decoder tensorrt infer --- cosyvoice/bin/export_trt.py | 105 ++++++++++++++++++++++++++++++-- cosyvoice/cli/cosyvoice.py | 5 +- cosyvoice/cli/model.py | 23 +++++-- cosyvoice/flow/decoder.py | 2 +- cosyvoice/flow/flow_matching.py | 33 ++++++++-- 5 files changed, 149 insertions(+), 19 deletions(-) diff --git a/cosyvoice/bin/export_trt.py b/cosyvoice/bin/export_trt.py index e6d480c..fea8205 100644 --- a/cosyvoice/bin/export_trt.py +++ b/cosyvoice/bin/export_trt.py @@ -1,8 +1,103 @@ -# TODO 跟export_jit一样的逻辑,完成flow部分的estimator的onnx导出。 -# tensorrt的安装方式,再这里写一下步骤提示如下,如果没有安装,那么不要执行这个脚本,提示用户先安装,不给选择 +import argparse +import logging +import os +import sys + +logging.getLogger('matplotlib').setLevel(logging.WARNING) + try: import tensorrt except ImportError: - print('step1, 下载\n step2. 解压,安装whl,') -# 安装命令里tensosrt的根目录用环境变量导入,比如os.environ['tensorrt_root_dir']/bin/exetrace,然后python里subprocess里执行导出命令 -# 后面我会在run.sh里写好执行命令 tensorrt_root_dir=xxxx python cosyvoice/bin/export_trt.py --model_dir xxx \ No newline at end of file + error_msg_zh = [ + "step.1 下载 tensorrt .tar.gz 压缩包并解压,下载地址: https://developer.nvidia.com/tensorrt/download/10x", + "step.2 使用 tensorrt whl 包进行安装根据 python 版本对应进行安装,如 pip install ${TensorRT-Path}/python/tensorrt-10.2.0-cp38-none-linux_x86_64.whl", + "step.3 将 tensorrt 的 lib 路径添加进环境变量中,export LD_LIBRARY_PATH=${TensorRT-Path}/lib/" + ] + print("\n".join(error_msg_zh)) + sys.exit(1) + +import torch +from cosyvoice.cli.cosyvoice import CosyVoice + +def get_args(): + parser = argparse.ArgumentParser(description='Export your model for deployment') + parser.add_argument('--model_dir', + type=str, + default='pretrained_models/CosyVoice-300M', + help='Local path to the model directory') + + parser.add_argument('--export_half', + action='store_true', + help='Export with half precision (FP16)') + + args = parser.parse_args() + print(args) + return args + +def main(): + args = get_args() + + cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False) + + flow = cosyvoice.model.flow + estimator = cosyvoice.model.flow.decoder.estimator + + dtype = torch.float32 if not args.export_half else torch.float16 + device = torch.device("cuda") + batch_size = 1 + seq_len = 1024 + hidden_size = flow.output_size + x = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device) + mask = torch.zeros((batch_size, 1, seq_len), dtype=dtype, device=device) + mu = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device) + t = torch.tensor([0.], dtype=dtype, device=device) + spks = torch.rand((batch_size, hidden_size), dtype=dtype, device=device) + cond = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device) + + onnx_file_name = 'estimator_fp16.onnx' if args.export_half else 'estimator_fp32.onnx' + onnx_file_path = os.path.join(args.model_dir, onnx_file_name) + dummy_input = (x, mask, mu, t, spks, cond) + + estimator = estimator.to(dtype) + + torch.onnx.export( + estimator, + dummy_input, + onnx_file_path, + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], + output_names=['output'], + dynamic_axes={ + 'x': {2: 'seq_len'}, + 'mask': {2: 'seq_len'}, + 'mu': {2: 'seq_len'}, + 'cond': {2: 'seq_len'}, + 'output': {2: 'seq_len'}, + } + ) + + tensorrt_path = os.environ.get('tensorrt_root_dir') + if not tensorrt_path: + raise EnvironmentError("Please set the 'tensorrt_root_dir' environment variable.") + + if not os.path.isdir(tensorrt_path): + raise FileNotFoundError(f"The directory {tensorrt_path} does not exist.") + + trt_lib_path = os.path.join(tensorrt_path, "lib") + if trt_lib_path not in os.environ.get('LD_LIBRARY_PATH', ''): + print(f"Adding TensorRT lib path {trt_lib_path} to LD_LIBRARY_PATH.") + os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('LD_LIBRARY_PATH', '')}:{trt_lib_path}" + + trt_file_name = 'estimator_fp16.plan' if args.export_half else 'estimator_fp32.plan' + trt_file_path = os.path.join(args.model_dir, trt_file_name) + + trtexec_cmd = f"{tensorrt_path}/bin/trtexec --onnx={onnx_file_path} --saveEngine={trt_file_path} " \ + "--minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 " \ + "--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose" + + os.system(trtexec_cmd) + +if __name__ == "__main__": + main() diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index d5fbd4e..eac8922 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging class CosyVoice: - def __init__(self, model_dir, load_jit=True, load_trt=True): + def __init__(self, model_dir, load_jit=True, load_trt=True, use_fp16=False): instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir if not os.path.exists(model_dir): @@ -43,8 +43,7 @@ class CosyVoice: self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir), '{}/llm.llm.fp16.zip'.format(model_dir)) if load_trt: - # TODO - self.model.load_trt() + self.model.load_trt(model_dir, use_fp16) del configs def list_avaliable_spks(self): diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 1184f0d..f074fd8 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import torch import numpy as np import threading @@ -19,6 +20,10 @@ from contextlib import nullcontext import uuid from cosyvoice.utils.common import fade_in_out +try: + import tensorrt as trt +except ImportError: + ... class CosyVoiceModel: @@ -66,10 +71,20 @@ class CosyVoiceModel: llm_llm = torch.jit.load(llm_llm_model) self.llm.llm = llm_llm - def load_trt(self): - # TODO 你需要的TRT推理的准备 - self.flow.decoder.estimator = xxx - self.flow.decoder.session = xxx + def load_trt(self, model_dir, use_fp16): + trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan' + trt_file_path = os.path.join(model_dir, trt_file_name) + if not os.path.isfile(trt_file_path): + raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file" + + trt.init_libnvinfer_plugins(None, "") + logger = trt.Logger(trt.Logger.WARNING) + runtime = trt.Runtime(logger) + with open(trt_file_path, 'rb') as f: + serialized_engine = f.read() + engine = runtime.deserialize_cuda_engine(serialized_engine) + self.flow.decoder.estimator_context = engine.create_execution_context() + self.flow.decoder.estimator_engine = engine def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): with self.llm_context: diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py index 4349279..be063d3 100755 --- a/cosyvoice/flow/decoder.py +++ b/cosyvoice/flow/decoder.py @@ -159,7 +159,7 @@ class ConditionalDecoder(nn.Module): _type_: _description_ """ - t = self.time_embeddings(t) + t = self.time_embeddings(t).to(t.dtype) t = self.time_mlp(t) x = pack([x, mu], "b * t")[0] diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index bcbaeb5..8cab545 100755 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -30,6 +30,9 @@ class ConditionalCFM(BASECFM): # Just change the architecture of the estimator here self.estimator = estimator + self.estimator_context = None + self.estimator_engine = None + @torch.inference_mode() def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): """Forward diffusion @@ -50,7 +53,7 @@ class ConditionalCFM(BASECFM): shape: (batch_size, n_feats, mel_timesteps) """ z = torch.randn_like(mu) * temperature - t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) if self.t_scheduler == 'cosine': t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) @@ -71,6 +74,7 @@ class ConditionalCFM(BASECFM): cond: Not used but kept for future purposes """ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + t = t.unsqueeze(dim=0) # I am storing this because I can later plot it by putting a debugger here and saving it to a file # Or in future might add like a return_all_steps flag @@ -96,13 +100,30 @@ class ConditionalCFM(BASECFM): return sol[-1] - # TODO - def forward_estimator(self): - if isinstance(self.estimator, trt): + def forward_estimator(self, x, mask, mu, t, spks, cond): + if self.estimator_context is not None: assert self.training is False, 'tensorrt cannot be used in training' - return xxx + bs = x.shape[0] + hs = x.shape[1] + seq_len = x.shape[2] + # assert bs == 1 and hs == 80 + ret = torch.empty_like(x) + self.estimator_context.set_input_shape("x", x.shape) + self.estimator_context.set_input_shape("mask", mask.shape) + self.estimator_context.set_input_shape("mu", mu.shape) + self.estimator_context.set_input_shape("t", t.shape) + self.estimator_context.set_input_shape("spks", spks.shape) + self.estimator_context.set_input_shape("cond", cond.shape) + bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()] + + for i in range(len(bindings)): + self.estimator_context.set_tensor_address(self.estimator_engine.get_tensor_name(i), bindings[i]) + + handle = torch.cuda.current_stream().cuda_stream + self.estimator_context.execute_async_v3(stream_handle=handle) + return ret else: - return self.estimator.forward + return self.estimator.forward(x, mask, mu, t, spks, cond) def compute_loss(self, x1, mask, mu, spks=None, cond=None): """Computes diffusion loss From 53a3c1b17fda78da56be12ac9bec3739467f90db Mon Sep 17 00:00:00 2001 From: "zhoubofan.zbf" Date: Fri, 30 Aug 2024 00:47:40 +0800 Subject: [PATCH 02/11] update --- cosyvoice/bin/export_trt.py | 26 +++++++++++++++--------- cosyvoice/cli/cosyvoice.py | 5 ++++- cosyvoice/flow/flow.py | 2 +- cosyvoice/flow/flow_matching.py | 36 +++++++++++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 11 deletions(-) diff --git a/cosyvoice/bin/export_trt.py b/cosyvoice/bin/export_trt.py index fea8205..0de874a 100644 --- a/cosyvoice/bin/export_trt.py +++ b/cosyvoice/bin/export_trt.py @@ -38,23 +38,21 @@ def main(): args = get_args() cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False) - - flow = cosyvoice.model.flow estimator = cosyvoice.model.flow.decoder.estimator dtype = torch.float32 if not args.export_half else torch.float16 device = torch.device("cuda") batch_size = 1 - seq_len = 1024 - hidden_size = flow.output_size + seq_len = 256 + hidden_size = cosyvoice.model.flow.output_size x = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device) - mask = torch.zeros((batch_size, 1, seq_len), dtype=dtype, device=device) + mask = torch.ones((batch_size, 1, seq_len), dtype=dtype, device=device) mu = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device) - t = torch.tensor([0.], dtype=dtype, device=device) + t = torch.rand((batch_size, ), dtype=dtype, device=device) spks = torch.rand((batch_size, hidden_size), dtype=dtype, device=device) cond = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device) - onnx_file_name = 'estimator_fp16.onnx' if args.export_half else 'estimator_fp32.onnx' + onnx_file_name = 'estimator_fp32.onnx' if not args.export_half else 'estimator_fp16.onnx' onnx_file_path = os.path.join(args.model_dir, onnx_file_name) dummy_input = (x, mask, mu, t, spks, cond) @@ -90,14 +88,24 @@ def main(): print(f"Adding TensorRT lib path {trt_lib_path} to LD_LIBRARY_PATH.") os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('LD_LIBRARY_PATH', '')}:{trt_lib_path}" - trt_file_name = 'estimator_fp16.plan' if args.export_half else 'estimator_fp32.plan' + trt_file_name = 'estimator_fp32.plan' if not args.export_half else 'estimator_fp16.plan' trt_file_path = os.path.join(args.model_dir, trt_file_name) trtexec_cmd = f"{tensorrt_path}/bin/trtexec --onnx={onnx_file_path} --saveEngine={trt_file_path} " \ "--minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 " \ - "--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose" + "--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose " + \ + ("--fp16" if args.export_half else "") +# /ossfs/workspace/TensorRT-10.2.0.19/bin/trtexec --onnx=estimator_fp32.onnx --saveEngine=estimator_fp32.plan --minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 --maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose + print("execute ", trtexec_cmd) os.system(trtexec_cmd) + print("x.shape", x.shape) + print("mask.shape", mask.shape) + print("mu.shape", mu.shape) + print("t.shape", t.shape) + print("spks.shape", spks.shape) + print("cond.shape", cond.shape) + if __name__ == "__main__": main() diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index eac8922..3558374 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging class CosyVoice: - def __init__(self, model_dir, load_jit=True, load_trt=True, use_fp16=False): + def __init__(self, model_dir, load_jit=True, load_trt=False, use_fp16=False): instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir if not os.path.exists(model_dir): @@ -39,11 +39,14 @@ class CosyVoice: self.model.load('{}/llm.pt'.format(model_dir), '{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) + load_jit = False if load_jit: self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir), '{}/llm.llm.fp16.zip'.format(model_dir)) + if load_trt: self.model.load_trt(model_dir, use_fp16) + del configs def list_avaliable_spks(self): diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index 5466542..4a40eed 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -107,7 +107,7 @@ class MaskedDiffWithXvec(torch.nn.Module): # concat text and prompt_text token_len1, token_len2 = prompt_token.shape[1], token.shape[1] token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len - mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding) + mask = (~make_pad_mask(token_len)).to(embedding.dtype).unsqueeze(-1).to(embedding) token = self.input_embedding(torch.clamp(token, min=0)) * mask # text encode diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 8cab545..ef024dc 100755 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -32,6 +32,7 @@ class ConditionalCFM(BASECFM): self.estimator_context = None self.estimator_engine = None + self.is_saved = None @torch.inference_mode() def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): @@ -123,6 +124,41 @@ class ConditionalCFM(BASECFM): self.estimator_context.execute_async_v3(stream_handle=handle) return ret else: + + if self.is_saved == None: + self.is_saved = True + output = self.estimator.forward(x, mask, mu, t, spks, cond) + torch.save(x, "x.pt") + torch.save(mask, "mask.pt") + torch.save(mu, "mu.pt") + torch.save(t, "t.pt") + torch.save(spks, "spks.pt") + torch.save(cond, "cond.pt") + torch.save(output, "output.pt") + dummy_input = (x, mask, mu, t, spks, cond) + torch.onnx.export( + self.estimator, + dummy_input, + "estimator_fp32.onnx", + export_params=True, + opset_version=17, + do_constant_folding=True, + input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], + output_names=['output'], + dynamic_axes={ + 'x': {2: 'seq_len'}, + 'mask': {2: 'seq_len'}, + 'mu': {2: 'seq_len'}, + 'cond': {2: 'seq_len'}, + 'output': {2: 'seq_len'}, + } + ) + # print("x, x.shape", x, x.shape) + # print("mask, mask.shape", mask, mask.shape) + # print("mu, mu.shape", mu, mu.shape) + # print("t, t.shape", t, t.shape) + # print("spks, spks.shape", spks, spks.shape) + # print("cond, cond.shape", cond, cond.shape) return self.estimator.forward(x, mask, mu, t, spks, cond) def compute_loss(self, x1, mask, mu, spks=None, cond=None): From 6e7f5b922a7ddca2d9537629fa624b997443970d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A6=BE=E6=81=AF?= Date: Fri, 30 Aug 2024 13:14:44 +0800 Subject: [PATCH 03/11] update --- cosyvoice/bin/export_trt.py | 6 ++-- cosyvoice/cli/model.py | 3 +- cosyvoice/flow/flow_matching.py | 52 +++++---------------------------- 3 files changed, 12 insertions(+), 49 deletions(-) diff --git a/cosyvoice/bin/export_trt.py b/cosyvoice/bin/export_trt.py index 0de874a..769bdca 100644 --- a/cosyvoice/bin/export_trt.py +++ b/cosyvoice/bin/export_trt.py @@ -66,13 +66,13 @@ def main(): opset_version=18, do_constant_folding=True, input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], - output_names=['output'], + output_names=['estimator_out'], dynamic_axes={ 'x': {2: 'seq_len'}, 'mask': {2: 'seq_len'}, 'mu': {2: 'seq_len'}, 'cond': {2: 'seq_len'}, - 'output': {2: 'seq_len'}, + 'estimator_out': {2: 'seq_len'}, } ) @@ -95,7 +95,7 @@ def main(): "--minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 " \ "--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose " + \ ("--fp16" if args.export_half else "") -# /ossfs/workspace/TensorRT-10.2.0.19/bin/trtexec --onnx=estimator_fp32.onnx --saveEngine=estimator_fp32.plan --minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 --maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose + print("execute ", trtexec_cmd) os.system(trtexec_cmd) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index f074fd8..59006a7 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -83,8 +83,7 @@ class CosyVoiceModel: with open(trt_file_path, 'rb') as f: serialized_engine = f.read() engine = runtime.deserialize_cuda_engine(serialized_engine) - self.flow.decoder.estimator_context = engine.create_execution_context() - self.flow.decoder.estimator_engine = engine + self.flow.decoder.estimator = engine.create_execution_context() def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): with self.llm_context: diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index ef024dc..8908d5f 100755 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -30,10 +30,6 @@ class ConditionalCFM(BASECFM): # Just change the architecture of the estimator here self.estimator = estimator - self.estimator_context = None - self.estimator_engine = None - self.is_saved = None - @torch.inference_mode() def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): """Forward diffusion @@ -102,7 +98,11 @@ class ConditionalCFM(BASECFM): return sol[-1] def forward_estimator(self, x, mask, mu, t, spks, cond): - if self.estimator_context is not None: + + if not isinstance(self.estimator, torch.nn.Module): + return self.estimator.forward(x, mask, mu, t, spks, cond) + + else: assert self.training is False, 'tensorrt cannot be used in training' bs = x.shape[0] hs = x.shape[1] @@ -116,50 +116,14 @@ class ConditionalCFM(BASECFM): self.estimator_context.set_input_shape("spks", spks.shape) self.estimator_context.set_input_shape("cond", cond.shape) bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()] + names = ['x', 'mask', 'mu', 't', 'spks', 'cond', 'estimator_out'] for i in range(len(bindings)): - self.estimator_context.set_tensor_address(self.estimator_engine.get_tensor_name(i), bindings[i]) + self.estimator.set_tensor_address(names[i], bindings[i]) handle = torch.cuda.current_stream().cuda_stream - self.estimator_context.execute_async_v3(stream_handle=handle) + self.estimator.execute_async_v3(stream_handle=handle) return ret - else: - - if self.is_saved == None: - self.is_saved = True - output = self.estimator.forward(x, mask, mu, t, spks, cond) - torch.save(x, "x.pt") - torch.save(mask, "mask.pt") - torch.save(mu, "mu.pt") - torch.save(t, "t.pt") - torch.save(spks, "spks.pt") - torch.save(cond, "cond.pt") - torch.save(output, "output.pt") - dummy_input = (x, mask, mu, t, spks, cond) - torch.onnx.export( - self.estimator, - dummy_input, - "estimator_fp32.onnx", - export_params=True, - opset_version=17, - do_constant_folding=True, - input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], - output_names=['output'], - dynamic_axes={ - 'x': {2: 'seq_len'}, - 'mask': {2: 'seq_len'}, - 'mu': {2: 'seq_len'}, - 'cond': {2: 'seq_len'}, - 'output': {2: 'seq_len'}, - } - ) - # print("x, x.shape", x, x.shape) - # print("mask, mask.shape", mask, mask.shape) - # print("mu, mu.shape", mu, mu.shape) - # print("t, t.shape", t, t.shape) - # print("spks, spks.shape", spks, spks.shape) - # print("cond, cond.shape", cond, cond.shape) - return self.estimator.forward(x, mask, mu, t, spks, cond) def compute_loss(self, x1, mask, mu, spks=None, cond=None): """Computes diffusion loss From 29408360fb4484b6ae53810b181ceec9b79611d9 Mon Sep 17 00:00:00 2001 From: "zhoubofan.zbf" Date: Fri, 30 Aug 2024 13:43:54 +0800 Subject: [PATCH 04/11] fix bug --- cosyvoice/bin/export_trt.py | 19 ++++++++++--------- cosyvoice/cli/cosyvoice.py | 4 ++-- cosyvoice/cli/model.py | 3 ++- cosyvoice/flow/flow_matching.py | 8 ++++---- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/cosyvoice/bin/export_trt.py b/cosyvoice/bin/export_trt.py index 769bdca..1bf958c 100644 --- a/cosyvoice/bin/export_trt.py +++ b/cosyvoice/bin/export_trt.py @@ -11,7 +11,7 @@ except ImportError: error_msg_zh = [ "step.1 下载 tensorrt .tar.gz 压缩包并解压,下载地址: https://developer.nvidia.com/tensorrt/download/10x", "step.2 使用 tensorrt whl 包进行安装根据 python 版本对应进行安装,如 pip install ${TensorRT-Path}/python/tensorrt-10.2.0-cp38-none-linux_x86_64.whl", - "step.3 将 tensorrt 的 lib 路径添加进环境变量中,export LD_LIBRARY_PATH=${TensorRT-Path}/lib/" + "step.3 将 tensorrt 的 lib 路径添加进环境变量中,export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${TensorRT-Path}/lib/" ] print("\n".join(error_msg_zh)) sys.exit(1) @@ -23,7 +23,7 @@ def get_args(): parser = argparse.ArgumentParser(description='Export your model for deployment') parser.add_argument('--model_dir', type=str, - default='pretrained_models/CosyVoice-300M', + default='pretrained_models/CosyVoice-300M-SFT', help='Local path to the model directory') parser.add_argument('--export_half', @@ -91,7 +91,8 @@ def main(): trt_file_name = 'estimator_fp32.plan' if not args.export_half else 'estimator_fp16.plan' trt_file_path = os.path.join(args.model_dir, trt_file_name) - trtexec_cmd = f"{tensorrt_path}/bin/trtexec --onnx={onnx_file_path} --saveEngine={trt_file_path} " \ + trtexec_bin = os.path.join(tensorrt_path, 'bin/trtexec') + trtexec_cmd = f"{trtexec_bin} --onnx={onnx_file_path} --saveEngine={trt_file_path} " \ "--minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 " \ "--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose " + \ ("--fp16" if args.export_half else "") @@ -100,12 +101,12 @@ def main(): os.system(trtexec_cmd) - print("x.shape", x.shape) - print("mask.shape", mask.shape) - print("mu.shape", mu.shape) - print("t.shape", t.shape) - print("spks.shape", spks.shape) - print("cond.shape", cond.shape) + # print("x.shape", x.shape) + # print("mask.shape", mask.shape) + # print("mu.shape", mu.shape) + # print("t.shape", t.shape) + # print("spks.shape", spks.shape) + # print("cond.shape", cond.shape) if __name__ == "__main__": main() diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 3558374..5028ad1 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging class CosyVoice: - def __init__(self, model_dir, load_jit=True, load_trt=False, use_fp16=False): + def __init__(self, model_dir, load_jit=True, load_trt=True, use_fp16=False): instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir if not os.path.exists(model_dir): @@ -39,7 +39,7 @@ class CosyVoice: self.model.load('{}/llm.pt'.format(model_dir), '{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) - load_jit = False + if load_jit: self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir), '{}/llm.llm.fp16.zip'.format(model_dir)) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 59006a7..bf18ea3 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -83,7 +83,8 @@ class CosyVoiceModel: with open(trt_file_path, 'rb') as f: serialized_engine = f.read() engine = runtime.deserialize_cuda_engine(serialized_engine) - self.flow.decoder.estimator = engine.create_execution_context() + self.flow.decoder.estimator_context = engine.create_execution_context() + self.flow.decoder.estimator = None def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): with self.llm_context: diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 8908d5f..18efe75 100755 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -99,10 +99,10 @@ class ConditionalCFM(BASECFM): def forward_estimator(self, x, mask, mu, t, spks, cond): - if not isinstance(self.estimator, torch.nn.Module): + if self.estimator is not None: return self.estimator.forward(x, mask, mu, t, spks, cond) - else: + print("-----------") assert self.training is False, 'tensorrt cannot be used in training' bs = x.shape[0] hs = x.shape[1] @@ -119,10 +119,10 @@ class ConditionalCFM(BASECFM): names = ['x', 'mask', 'mu', 't', 'spks', 'cond', 'estimator_out'] for i in range(len(bindings)): - self.estimator.set_tensor_address(names[i], bindings[i]) + self.estimator_context.set_tensor_address(names[i], bindings[i]) handle = torch.cuda.current_stream().cuda_stream - self.estimator.execute_async_v3(stream_handle=handle) + self.estimator_context.execute_async_v3(stream_handle=handle) return ret def compute_loss(self, x1, mask, mu, spks=None, cond=None): From 18599be8d5edba918fcaac9e0f55ecd6ab4e045c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A6=BE=E6=81=AF?= Date: Fri, 30 Aug 2024 14:15:24 +0800 Subject: [PATCH 05/11] mirror modify --- cosyvoice/bin/export_trt.py | 14 ++++++++++++++ cosyvoice/cli/model.py | 6 +----- cosyvoice/flow/flow_matching.py | 1 - 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/cosyvoice/bin/export_trt.py b/cosyvoice/bin/export_trt.py index 1bf958c..c737373 100644 --- a/cosyvoice/bin/export_trt.py +++ b/cosyvoice/bin/export_trt.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse import logging import os diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index bf18ea3..50ae0b1 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -20,11 +20,6 @@ from contextlib import nullcontext import uuid from cosyvoice.utils.common import fade_in_out -try: - import tensorrt as trt -except ImportError: - ... - class CosyVoiceModel: def __init__(self, @@ -72,6 +67,7 @@ class CosyVoiceModel: self.llm.llm = llm_llm def load_trt(self, model_dir, use_fp16): + import tensorrt as trt trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan' trt_file_path = os.path.join(model_dir, trt_file_name) if not os.path.isfile(trt_file_path): diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 18efe75..a31506a 100755 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -102,7 +102,6 @@ class ConditionalCFM(BASECFM): if self.estimator is not None: return self.estimator.forward(x, mask, mu, t, spks, cond) else: - print("-----------") assert self.training is False, 'tensorrt cannot be used in training' bs = x.shape[0] hs = x.shape[1] From fadb22086f44c7c5c0598b5ddd58c935ee5958c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A6=BE=E6=81=AF?= Date: Tue, 3 Sep 2024 11:06:24 +0800 Subject: [PATCH 06/11] export onnx --- cosyvoice/bin/export_onnx.py | 228 ++++++++++++++++++++++++++++++++ cosyvoice/bin/export_trt.py | 126 ------------------ cosyvoice/cli/cosyvoice.py | 11 +- cosyvoice/cli/model.py | 52 ++++++-- cosyvoice/flow/flow_matching.py | 65 ++++++--- 5 files changed, 318 insertions(+), 164 deletions(-) create mode 100644 cosyvoice/bin/export_onnx.py delete mode 100644 cosyvoice/bin/export_trt.py diff --git a/cosyvoice/bin/export_onnx.py b/cosyvoice/bin/export_onnx.py new file mode 100644 index 0000000..6ef4ab1 --- /dev/null +++ b/cosyvoice/bin/export_onnx.py @@ -0,0 +1,228 @@ +# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +import sys + +logging.getLogger('matplotlib').setLevel(logging.WARNING) +import onnxruntime as ort +import numpy as np + +# try: +# import tensorrt +# import tensorrt as trt +# except ImportError: +# error_msg_zh = [ +# "step.1 下载 tensorrt .tar.gz 压缩包并解压,下载地址: https://developer.nvidia.com/tensorrt/download/10x", +# "step.2 使用 tensorrt whl 包进行安装根据 python 版本对应进行安装,如 pip install ${TensorRT-Path}/python/tensorrt-10.2.0-cp38-none-linux_x86_64.whl", +# "step.3 将 tensorrt 的 lib 路径添加进环境变量中,export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${TensorRT-Path}/lib/" +# ] +# print("\n".join(error_msg_zh)) +# sys.exit(1) + +import torch +from cosyvoice.cli.cosyvoice import CosyVoice + + +def calculate_onnx(onnx_file, x, mask, mu, t, spks, cond): + providers = ['CUDAExecutionProvider'] + sess_options = ort.SessionOptions() + + providers = [ + 'CUDAExecutionProvider' + ] + + # Load the ONNX model + session = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers) + + x_np = x.cpu().numpy() + mask_np = mask.cpu().numpy() + mu_np = mu.cpu().numpy() + t_np = np.array(t.cpu()) + spks_np = spks.cpu().numpy() + cond_np = cond.cpu().numpy() + + ort_inputs = { + 'x': x_np, + 'mask': mask_np, + 'mu': mu_np, + 't': t_np, + 'spks': spks_np, + 'cond': cond_np + } + + output = session.run(None, ort_inputs) + + return output[0] + +# def calculate_tensorrt(trt_file, x, mask, mu, t, spks, cond): +# trt.init_libnvinfer_plugins(None, "") +# logger = trt.Logger(trt.Logger.WARNING) +# runtime = trt.Runtime(logger) +# with open(trt_file, 'rb') as f: +# serialized_engine = f.read() +# engine = runtime.deserialize_cuda_engine(serialized_engine) +# context = engine.create_execution_context() + +# bs = x.shape[0] +# hs = x.shape[1] +# seq_len = x.shape[2] + +# ret = torch.zeros_like(x) + +# # Set input shapes for dynamic dimensions +# context.set_input_shape("x", x.shape) +# context.set_input_shape("mask", mask.shape) +# context.set_input_shape("mu", mu.shape) +# context.set_input_shape("t", t.shape) +# context.set_input_shape("spks", spks.shape) +# context.set_input_shape("cond", cond.shape) + +# # bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()] +# # names = ['x', 'mask', 'mu', 't', 'spks', 'cond', 'estimator_out'] +# # +# # for i in range(len(bindings)): +# # context.set_tensor_address(names[i], bindings[i]) +# # +# # handle = torch.cuda.current_stream().cuda_stream +# # context.execute_async_v3(stream_handle=handle) + +# # Create a list of bindings +# bindings = [int(x.data_ptr()), int(mask.data_ptr()), int(mu.data_ptr()), int(t.data_ptr()), int(spks.data_ptr()), int(cond.data_ptr()), int(ret.data_ptr())] + +# # Execute the inference +# context.execute_v2(bindings=bindings) + +# torch.cuda.synchronize() + +# return ret + + +# def test_calculate_value(estimator, onnx_file, trt_file, dummy_input, args): +# torch_output = estimator.forward(**dummy_input).cpu().detach().numpy() +# onnx_output = calculate_onnx(onnx_file, **dummy_input) +# tensorrt_output = calculate_tensorrt(trt_file, **dummy_input).cpu().detach().numpy() +# atol = 2e-3 # Absolute tolerance +# rtol = 1e-4 # Relative tolerance + +# print(f"args.export_half: {args.export_half}, args.model_dir: {args.model_dir}") +# print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$") + +# print("torch_output diff with onnx_output: ", ) +# print(f"compare with atol: {atol}, rtol: {rtol} ", np.allclose(torch_output, onnx_output, atol, rtol)) +# print(f"max diff value: ", np.max(np.fabs(torch_output - onnx_output))) +# print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$") + +# print("torch_output diff with tensorrt_output: ") +# print(f"compare with atol: {atol}, rtol: {rtol} ", np.allclose(torch_output, tensorrt_output, atol, rtol)) +# print(f"max diff value: ", np.max(np.fabs(torch_output - tensorrt_output))) +# print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$") + +# print("onnx_output diff with tensorrt_output: ") +# print(f"compare with atol: {atol}, rtol: {rtol} ", np.allclose(onnx_output, tensorrt_output, atol, rtol)) +# print(f"max diff value: ", np.max(np.fabs(onnx_output - tensorrt_output))) +# print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$") + + +def get_args(): + parser = argparse.ArgumentParser(description='Export your model for deployment') + parser.add_argument('--model_dir', type=str, default='pretrained_models/CosyVoice-300M', help='Local path to the model directory') + parser.add_argument('--export_half', type=str, choices=['True', 'False'], default='False', help='Export with half precision (FP16)') + # parser.add_argument('--trt_max_len', type=int, default=8192, help='Export max len') + parser.add_argument('--exec_export', type=str, choices=['True', 'False'], default='True', help='Exec export') + + args = parser.parse_args() + args.export_half = args.export_half == 'True' + args.exec_export = args.exec_export == 'True' + print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$") + print(args) + return args + +def main(): + args = get_args() + + cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False) + estimator = cosyvoice.model.flow.decoder.estimator + + dtype = torch.float32 if not args.export_half else torch.float16 + device = torch.device("cuda") + batch_size = 1 + seq_len = 256 + out_channels = cosyvoice.model.flow.decoder.estimator.out_channels + x = torch.rand((batch_size, out_channels, seq_len), dtype=dtype, device=device) + mask = torch.ones((batch_size, 1, seq_len), dtype=dtype, device=device) + mu = torch.rand((batch_size, out_channels, seq_len), dtype=dtype, device=device) + t = torch.rand((batch_size, ), dtype=dtype, device=device) + spks = torch.rand((batch_size, out_channels), dtype=dtype, device=device) + cond = torch.rand((batch_size, out_channels, seq_len), dtype=dtype, device=device) + + onnx_file_name = 'estimator_fp32.onnx' if not args.export_half else 'estimator_fp16.onnx' + onnx_file_path = os.path.join(args.model_dir, onnx_file_name) + dummy_input = (x, mask, mu, t, spks, cond) + + estimator = estimator.to(dtype) + + if args.exec_export: + torch.onnx.export( + estimator, + dummy_input, + onnx_file_path, + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], + output_names=['estimator_out'], + dynamic_axes={ + 'x': {2: 'seq_len'}, + 'mask': {2: 'seq_len'}, + 'mu': {2: 'seq_len'}, + 'cond': {2: 'seq_len'}, + 'estimator_out': {2: 'seq_len'}, + } + ) + + # tensorrt_path = os.environ.get('tensorrt_root_dir') + # if not tensorrt_path: + # raise EnvironmentError("Please set the 'tensorrt_root_dir' environment variable.") + + # if not os.path.isdir(tensorrt_path): + # raise FileNotFoundError(f"The directory {tensorrt_path} does not exist.") + + # trt_lib_path = os.path.join(tensorrt_path, "lib") + # if trt_lib_path not in os.environ.get('LD_LIBRARY_PATH', ''): + # print(f"Adding TensorRT lib path {trt_lib_path} to LD_LIBRARY_PATH.") + # os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('LD_LIBRARY_PATH', '')}:{trt_lib_path}" + + # trt_file_name = 'estimator_fp32.plan' if not args.export_half else 'estimator_fp16.plan' + # trt_file_path = os.path.join(args.model_dir, trt_file_name) + + # trtexec_bin = os.path.join(tensorrt_path, 'bin/trtexec') + # trt_max_len = args.trt_max_len + # trtexec_cmd = f"{trtexec_bin} --onnx={onnx_file_path} --saveEngine={trt_file_path} " \ + # f"--minShapes=x:1x{out_channels}x1,mask:1x1x1,mu:1x{out_channels}x1,t:1,spks:1x{out_channels},cond:1x{out_channels}x1 " \ + # f"--maxShapes=x:1x{out_channels}x{trt_max_len},mask:1x1x{trt_max_len},mu:1x{out_channels}x{trt_max_len},t:1,spks:1x{out_channels},cond:1x{out_channels}x{trt_max_len} " + \ + # ("--fp16" if args.export_half else "") + + # print("execute ", trtexec_cmd) + + # if args.exec_export: + # os.system(trtexec_cmd) + + # dummy_input = {'x': x, 'mask': mask, 'mu': mu, 't': t, 'spks': spks, 'cond': cond} + # test_calculate_value(estimator, onnx_file_path, trt_file_path, dummy_input, args) + +if __name__ == "__main__": + main() diff --git a/cosyvoice/bin/export_trt.py b/cosyvoice/bin/export_trt.py deleted file mode 100644 index c737373..0000000 --- a/cosyvoice/bin/export_trt.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -import os -import sys - -logging.getLogger('matplotlib').setLevel(logging.WARNING) - -try: - import tensorrt -except ImportError: - error_msg_zh = [ - "step.1 下载 tensorrt .tar.gz 压缩包并解压,下载地址: https://developer.nvidia.com/tensorrt/download/10x", - "step.2 使用 tensorrt whl 包进行安装根据 python 版本对应进行安装,如 pip install ${TensorRT-Path}/python/tensorrt-10.2.0-cp38-none-linux_x86_64.whl", - "step.3 将 tensorrt 的 lib 路径添加进环境变量中,export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${TensorRT-Path}/lib/" - ] - print("\n".join(error_msg_zh)) - sys.exit(1) - -import torch -from cosyvoice.cli.cosyvoice import CosyVoice - -def get_args(): - parser = argparse.ArgumentParser(description='Export your model for deployment') - parser.add_argument('--model_dir', - type=str, - default='pretrained_models/CosyVoice-300M-SFT', - help='Local path to the model directory') - - parser.add_argument('--export_half', - action='store_true', - help='Export with half precision (FP16)') - - args = parser.parse_args() - print(args) - return args - -def main(): - args = get_args() - - cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False) - estimator = cosyvoice.model.flow.decoder.estimator - - dtype = torch.float32 if not args.export_half else torch.float16 - device = torch.device("cuda") - batch_size = 1 - seq_len = 256 - hidden_size = cosyvoice.model.flow.output_size - x = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device) - mask = torch.ones((batch_size, 1, seq_len), dtype=dtype, device=device) - mu = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device) - t = torch.rand((batch_size, ), dtype=dtype, device=device) - spks = torch.rand((batch_size, hidden_size), dtype=dtype, device=device) - cond = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device) - - onnx_file_name = 'estimator_fp32.onnx' if not args.export_half else 'estimator_fp16.onnx' - onnx_file_path = os.path.join(args.model_dir, onnx_file_name) - dummy_input = (x, mask, mu, t, spks, cond) - - estimator = estimator.to(dtype) - - torch.onnx.export( - estimator, - dummy_input, - onnx_file_path, - export_params=True, - opset_version=18, - do_constant_folding=True, - input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], - output_names=['estimator_out'], - dynamic_axes={ - 'x': {2: 'seq_len'}, - 'mask': {2: 'seq_len'}, - 'mu': {2: 'seq_len'}, - 'cond': {2: 'seq_len'}, - 'estimator_out': {2: 'seq_len'}, - } - ) - - tensorrt_path = os.environ.get('tensorrt_root_dir') - if not tensorrt_path: - raise EnvironmentError("Please set the 'tensorrt_root_dir' environment variable.") - - if not os.path.isdir(tensorrt_path): - raise FileNotFoundError(f"The directory {tensorrt_path} does not exist.") - - trt_lib_path = os.path.join(tensorrt_path, "lib") - if trt_lib_path not in os.environ.get('LD_LIBRARY_PATH', ''): - print(f"Adding TensorRT lib path {trt_lib_path} to LD_LIBRARY_PATH.") - os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('LD_LIBRARY_PATH', '')}:{trt_lib_path}" - - trt_file_name = 'estimator_fp32.plan' if not args.export_half else 'estimator_fp16.plan' - trt_file_path = os.path.join(args.model_dir, trt_file_name) - - trtexec_bin = os.path.join(tensorrt_path, 'bin/trtexec') - trtexec_cmd = f"{trtexec_bin} --onnx={onnx_file_path} --saveEngine={trt_file_path} " \ - "--minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 " \ - "--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose " + \ - ("--fp16" if args.export_half else "") - - print("execute ", trtexec_cmd) - - os.system(trtexec_cmd) - - # print("x.shape", x.shape) - # print("mask.shape", mask.shape) - # print("mu.shape", mu.shape) - # print("t.shape", t.shape) - # print("spks.shape", spks.shape) - # print("cond.shape", cond.shape) - -if __name__ == "__main__": - main() diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 5028ad1..cf5e4e5 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging class CosyVoice: - def __init__(self, model_dir, load_jit=True, load_trt=True, use_fp16=False): + def __init__(self, model_dir, load_jit=True, load_trt=False, load_onnx=True, use_fp16=False): instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir if not os.path.exists(model_dir): @@ -39,13 +39,16 @@ class CosyVoice: self.model.load('{}/llm.pt'.format(model_dir), '{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) - + if load_jit: self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir), '{}/llm.llm.fp16.zip'.format(model_dir)) - if load_trt: - self.model.load_trt(model_dir, use_fp16) + # if load_trt: + # self.model.load_trt(model_dir, use_fp16) + + if load_onnx: + self.model.load_onnx(model_dir, use_fp16) del configs diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 50ae0b1..8401d42 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -19,6 +19,13 @@ import time from contextlib import nullcontext import uuid from cosyvoice.utils.common import fade_in_out +import numpy as np +import onnxruntime as ort + +# try: +# import tensorrt as trt +# except ImportError: +# ... class CosyVoiceModel: @@ -66,21 +73,40 @@ class CosyVoiceModel: llm_llm = torch.jit.load(llm_llm_model) self.llm.llm = llm_llm - def load_trt(self, model_dir, use_fp16): - import tensorrt as trt - trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan' - trt_file_path = os.path.join(model_dir, trt_file_name) - if not os.path.isfile(trt_file_path): - raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file" + # def load_trt(self, model_dir, use_fp16): + # trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan' + # trt_file_path = os.path.join(model_dir, trt_file_name) + # if not os.path.isfile(trt_file_path): + # raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file" - trt.init_libnvinfer_plugins(None, "") - logger = trt.Logger(trt.Logger.WARNING) - runtime = trt.Runtime(logger) - with open(trt_file_path, 'rb') as f: - serialized_engine = f.read() - engine = runtime.deserialize_cuda_engine(serialized_engine) - self.flow.decoder.estimator_context = engine.create_execution_context() + # trt.init_libnvinfer_plugins(None, "") + # logger = trt.Logger(trt.Logger.WARNING) + # runtime = trt.Runtime(logger) + # with open(trt_file_path, 'rb') as f: + # serialized_engine = f.read() + # engine = runtime.deserialize_cuda_engine(serialized_engine) + # self.flow.decoder.estimator_context = engine.create_execution_context() + # self.flow.decoder.estimator = None + + def load_onnx(self, model_dir, use_fp16): + onnx_file_name = 'estimator_fp16.onnx' if use_fp16 else 'estimator_fp32.onnx' + onnx_file_path = os.path.join(model_dir, onnx_file_name) + if not os.path.isfile(onnx_file_path): + raise f"{onnx_file_path} does not exist. Please use bin/export_trt.py to generate .onnx file" + + providers = ['CUDAExecutionProvider'] + sess_options = ort.SessionOptions() + + # Add TensorRT Execution Provider + providers = [ + 'CUDAExecutionProvider' + ] + + # Load the ONNX model + self.flow.decoder.session = ort.InferenceSession(onnx_file_path, sess_options=sess_options, providers=providers) + # self.flow.decoder.estimator_context = None self.flow.decoder.estimator = None + def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): with self.llm_context: diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index a31506a..27e2276 100755 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -14,6 +14,8 @@ import torch import torch.nn.functional as F from matcha.models.components.flow_matching import BASECFM +import onnxruntime as ort +import numpy as np class ConditionalCFM(BASECFM): def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None): @@ -29,6 +31,8 @@ class ConditionalCFM(BASECFM): in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) # Just change the architecture of the estimator here self.estimator = estimator + self.estimator_context = None # for tensorrt + self.session = None # for onnx @torch.inference_mode() def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): @@ -101,28 +105,47 @@ class ConditionalCFM(BASECFM): if self.estimator is not None: return self.estimator.forward(x, mask, mu, t, spks, cond) - else: - assert self.training is False, 'tensorrt cannot be used in training' - bs = x.shape[0] - hs = x.shape[1] - seq_len = x.shape[2] - # assert bs == 1 and hs == 80 - ret = torch.empty_like(x) - self.estimator_context.set_input_shape("x", x.shape) - self.estimator_context.set_input_shape("mask", mask.shape) - self.estimator_context.set_input_shape("mu", mu.shape) - self.estimator_context.set_input_shape("t", t.shape) - self.estimator_context.set_input_shape("spks", spks.shape) - self.estimator_context.set_input_shape("cond", cond.shape) - bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()] - names = ['x', 'mask', 'mu', 't', 'spks', 'cond', 'estimator_out'] - - for i in range(len(bindings)): - self.estimator_context.set_tensor_address(names[i], bindings[i]) + # elif self.estimator_context is not None: + # assert self.training is False, 'tensorrt cannot be used in training' + # bs = x.shape[0] + # hs = x.shape[1] + # seq_len = x.shape[2] + # # assert bs == 1 and hs == 80 + # ret = torch.empty_like(x) + # self.estimator_context.set_input_shape("x", x.shape) + # self.estimator_context.set_input_shape("mask", mask.shape) + # self.estimator_context.set_input_shape("mu", mu.shape) + # self.estimator_context.set_input_shape("t", t.shape) + # self.estimator_context.set_input_shape("spks", spks.shape) + # self.estimator_context.set_input_shape("cond", cond.shape) + + # # Create a list of bindings + # bindings = [int(x.data_ptr()), int(mask.data_ptr()), int(mu.data_ptr()), int(t.data_ptr()), int(spks.data_ptr()), int(cond.data_ptr()), int(ret.data_ptr())] + + # # Execute the inference + # self.estimator_context.execute_v2(bindings=bindings) + # return ret + else: + x_np = x.cpu().numpy() + mask_np = mask.cpu().numpy() + mu_np = mu.cpu().numpy() + t_np = t.cpu().numpy() + spks_np = spks.cpu().numpy() + cond_np = cond.cpu().numpy() + + ort_inputs = { + 'x': x_np, + 'mask': mask_np, + 'mu': mu_np, + 't': t_np, + 'spks': spks_np, + 'cond': cond_np + } + + output = self.session.run(None, ort_inputs)[0] + + return torch.tensor(output, dtype=x.dtype, device=x.device) - handle = torch.cuda.current_stream().cuda_stream - self.estimator_context.execute_async_v3(stream_handle=handle) - return ret def compute_loss(self, x1, mask, mu, spks=None, cond=None): """Computes diffusion loss From a801416805d37de1188b7b6dd4e089be3e4b0f7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A6=BE=E6=81=AF?= Date: Tue, 3 Sep 2024 11:07:47 +0800 Subject: [PATCH 07/11] mirror modify --- cosyvoice/cli/model.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 8401d42..c31f281 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -22,11 +22,6 @@ from cosyvoice.utils.common import fade_in_out import numpy as np import onnxruntime as ort -# try: -# import tensorrt as trt -# except ImportError: -# ... - class CosyVoiceModel: def __init__(self, @@ -74,6 +69,7 @@ class CosyVoiceModel: self.llm.llm = llm_llm # def load_trt(self, model_dir, use_fp16): + # import tensorrt as trt # trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan' # trt_file_path = os.path.join(model_dir, trt_file_name) # if not os.path.isfile(trt_file_path): From 2ce724045b6d11461c1c274bea4e4aff574b50bc Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Wed, 4 Sep 2024 18:15:33 +0800 Subject: [PATCH 08/11] add onnx export --- cosyvoice/bin/export_jit.py | 9 +- cosyvoice/bin/export_onnx.py | 257 +++++++++----------------------- cosyvoice/cli/cosyvoice.py | 13 +- cosyvoice/cli/model.py | 52 ++----- cosyvoice/flow/flow_matching.py | 53 ++----- requirements.txt | 1 + 6 files changed, 105 insertions(+), 280 deletions(-) diff --git a/cosyvoice/bin/export_jit.py b/cosyvoice/bin/export_jit.py index 1eceb1d..cbd0f18 100644 --- a/cosyvoice/bin/export_jit.py +++ b/cosyvoice/bin/export_jit.py @@ -44,7 +44,7 @@ def main(): torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_executor(False) - cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False) + cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False) # 1. export llm text_encoder llm_text_encoder = cosyvoice.model.llm.text_encoder.half() @@ -60,5 +60,12 @@ def main(): script = torch.jit.optimize_for_inference(script) script.save('{}/llm.llm.fp16.zip'.format(args.model_dir)) + # 3. export flow encoder + flow_encoder = cosyvoice.model.flow.encoder + script = torch.jit.script(flow_encoder) + script = torch.jit.freeze(script) + script = torch.jit.optimize_for_inference(script) + script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) + if __name__ == '__main__': main() diff --git a/cosyvoice/bin/export_onnx.py b/cosyvoice/bin/export_onnx.py index 6ef4ab1..58b5ab6 100644 --- a/cosyvoice/bin/export_onnx.py +++ b/cosyvoice/bin/export_onnx.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,217 +13,97 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function + import argparse import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) import os import sys - -logging.getLogger('matplotlib').setLevel(logging.WARNING) -import onnxruntime as ort -import numpy as np - -# try: -# import tensorrt -# import tensorrt as trt -# except ImportError: -# error_msg_zh = [ -# "step.1 下载 tensorrt .tar.gz 压缩包并解压,下载地址: https://developer.nvidia.com/tensorrt/download/10x", -# "step.2 使用 tensorrt whl 包进行安装根据 python 版本对应进行安装,如 pip install ${TensorRT-Path}/python/tensorrt-10.2.0-cp38-none-linux_x86_64.whl", -# "step.3 将 tensorrt 的 lib 路径添加进环境变量中,export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${TensorRT-Path}/lib/" -# ] -# print("\n".join(error_msg_zh)) -# sys.exit(1) - +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append('{}/../..'.format(ROOT_DIR)) +sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) +import onnxruntime +import random import torch +from tqdm import tqdm from cosyvoice.cli.cosyvoice import CosyVoice -def calculate_onnx(onnx_file, x, mask, mu, t, spks, cond): - providers = ['CUDAExecutionProvider'] - sess_options = ort.SessionOptions() - - providers = [ - 'CUDAExecutionProvider' - ] - - # Load the ONNX model - session = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers) - - x_np = x.cpu().numpy() - mask_np = mask.cpu().numpy() - mu_np = mu.cpu().numpy() - t_np = np.array(t.cpu()) - spks_np = spks.cpu().numpy() - cond_np = cond.cpu().numpy() - - ort_inputs = { - 'x': x_np, - 'mask': mask_np, - 'mu': mu_np, - 't': t_np, - 'spks': spks_np, - 'cond': cond_np - } - - output = session.run(None, ort_inputs) - - return output[0] - -# def calculate_tensorrt(trt_file, x, mask, mu, t, spks, cond): -# trt.init_libnvinfer_plugins(None, "") -# logger = trt.Logger(trt.Logger.WARNING) -# runtime = trt.Runtime(logger) -# with open(trt_file, 'rb') as f: -# serialized_engine = f.read() -# engine = runtime.deserialize_cuda_engine(serialized_engine) -# context = engine.create_execution_context() - -# bs = x.shape[0] -# hs = x.shape[1] -# seq_len = x.shape[2] - -# ret = torch.zeros_like(x) - -# # Set input shapes for dynamic dimensions -# context.set_input_shape("x", x.shape) -# context.set_input_shape("mask", mask.shape) -# context.set_input_shape("mu", mu.shape) -# context.set_input_shape("t", t.shape) -# context.set_input_shape("spks", spks.shape) -# context.set_input_shape("cond", cond.shape) - -# # bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()] -# # names = ['x', 'mask', 'mu', 't', 'spks', 'cond', 'estimator_out'] -# # -# # for i in range(len(bindings)): -# # context.set_tensor_address(names[i], bindings[i]) -# # -# # handle = torch.cuda.current_stream().cuda_stream -# # context.execute_async_v3(stream_handle=handle) - -# # Create a list of bindings -# bindings = [int(x.data_ptr()), int(mask.data_ptr()), int(mu.data_ptr()), int(t.data_ptr()), int(spks.data_ptr()), int(cond.data_ptr()), int(ret.data_ptr())] - -# # Execute the inference -# context.execute_v2(bindings=bindings) - -# torch.cuda.synchronize() - -# return ret - - -# def test_calculate_value(estimator, onnx_file, trt_file, dummy_input, args): -# torch_output = estimator.forward(**dummy_input).cpu().detach().numpy() -# onnx_output = calculate_onnx(onnx_file, **dummy_input) -# tensorrt_output = calculate_tensorrt(trt_file, **dummy_input).cpu().detach().numpy() -# atol = 2e-3 # Absolute tolerance -# rtol = 1e-4 # Relative tolerance - -# print(f"args.export_half: {args.export_half}, args.model_dir: {args.model_dir}") -# print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$") - -# print("torch_output diff with onnx_output: ", ) -# print(f"compare with atol: {atol}, rtol: {rtol} ", np.allclose(torch_output, onnx_output, atol, rtol)) -# print(f"max diff value: ", np.max(np.fabs(torch_output - onnx_output))) -# print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$") - -# print("torch_output diff with tensorrt_output: ") -# print(f"compare with atol: {atol}, rtol: {rtol} ", np.allclose(torch_output, tensorrt_output, atol, rtol)) -# print(f"max diff value: ", np.max(np.fabs(torch_output - tensorrt_output))) -# print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$") - -# print("onnx_output diff with tensorrt_output: ") -# print(f"compare with atol: {atol}, rtol: {rtol} ", np.allclose(onnx_output, tensorrt_output, atol, rtol)) -# print(f"max diff value: ", np.max(np.fabs(onnx_output - tensorrt_output))) -# print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$") +def get_dummy_input(batch_size, seq_len, out_channels, device): + x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device) + mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + t = torch.rand((batch_size), dtype=torch.float32, device=device) + spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) + cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + return x, mask, mu, t, spks, cond def get_args(): - parser = argparse.ArgumentParser(description='Export your model for deployment') - parser.add_argument('--model_dir', type=str, default='pretrained_models/CosyVoice-300M', help='Local path to the model directory') - parser.add_argument('--export_half', type=str, choices=['True', 'False'], default='False', help='Export with half precision (FP16)') - # parser.add_argument('--trt_max_len', type=int, default=8192, help='Export max len') - parser.add_argument('--exec_export', type=str, choices=['True', 'False'], default='True', help='Exec export') - + parser = argparse.ArgumentParser(description='export your model for deployment') + parser.add_argument('--model_dir', + type=str, + default='pretrained_models/CosyVoice-300M', + help='local path') args = parser.parse_args() - args.export_half = args.export_half == 'True' - args.exec_export = args.exec_export == 'True' - print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$") print(args) return args def main(): args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') - cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False) + cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False) + + # 1. export flow decoder estimator estimator = cosyvoice.model.flow.decoder.estimator - dtype = torch.float32 if not args.export_half else torch.float16 - device = torch.device("cuda") - batch_size = 1 - seq_len = 256 + device = cosyvoice.model.device + batch_size, seq_len = 1, 256 out_channels = cosyvoice.model.flow.decoder.estimator.out_channels - x = torch.rand((batch_size, out_channels, seq_len), dtype=dtype, device=device) - mask = torch.ones((batch_size, 1, seq_len), dtype=dtype, device=device) - mu = torch.rand((batch_size, out_channels, seq_len), dtype=dtype, device=device) - t = torch.rand((batch_size, ), dtype=dtype, device=device) - spks = torch.rand((batch_size, out_channels), dtype=dtype, device=device) - cond = torch.rand((batch_size, out_channels, seq_len), dtype=dtype, device=device) + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) + torch.onnx.export( + estimator, + (x, mask, mu, t, spks, cond), + '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], + output_names=['estimator_out'], + dynamic_axes={ + 'x': {0: 'batch_size', 2: 'seq_len'}, + 'mask': {0: 'batch_size', 2: 'seq_len'}, + 'mu': {0: 'batch_size', 2: 'seq_len'}, + 'cond': {0: 'batch_size', 2: 'seq_len'}, + 't': {0: 'batch_size'}, + 'spks': {0: 'batch_size'}, + 'estimator_out': {0: 'batch_size', 2: 'seq_len'}, + } + ) - onnx_file_name = 'estimator_fp32.onnx' if not args.export_half else 'estimator_fp16.onnx' - onnx_file_path = os.path.join(args.model_dir, onnx_file_name) - dummy_input = (x, mask, mu, t, spks, cond) + # 2. test computation consistency + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), sess_options=option, providers=providers) - estimator = estimator.to(dtype) - - if args.exec_export: - torch.onnx.export( - estimator, - dummy_input, - onnx_file_path, - export_params=True, - opset_version=18, - do_constant_folding=True, - input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], - output_names=['estimator_out'], - dynamic_axes={ - 'x': {2: 'seq_len'}, - 'mask': {2: 'seq_len'}, - 'mu': {2: 'seq_len'}, - 'cond': {2: 'seq_len'}, - 'estimator_out': {2: 'seq_len'}, - } - ) - - # tensorrt_path = os.environ.get('tensorrt_root_dir') - # if not tensorrt_path: - # raise EnvironmentError("Please set the 'tensorrt_root_dir' environment variable.") - - # if not os.path.isdir(tensorrt_path): - # raise FileNotFoundError(f"The directory {tensorrt_path} does not exist.") - - # trt_lib_path = os.path.join(tensorrt_path, "lib") - # if trt_lib_path not in os.environ.get('LD_LIBRARY_PATH', ''): - # print(f"Adding TensorRT lib path {trt_lib_path} to LD_LIBRARY_PATH.") - # os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('LD_LIBRARY_PATH', '')}:{trt_lib_path}" - - # trt_file_name = 'estimator_fp32.plan' if not args.export_half else 'estimator_fp16.plan' - # trt_file_path = os.path.join(args.model_dir, trt_file_name) - - # trtexec_bin = os.path.join(tensorrt_path, 'bin/trtexec') - # trt_max_len = args.trt_max_len - # trtexec_cmd = f"{trtexec_bin} --onnx={onnx_file_path} --saveEngine={trt_file_path} " \ - # f"--minShapes=x:1x{out_channels}x1,mask:1x1x1,mu:1x{out_channels}x1,t:1,spks:1x{out_channels},cond:1x{out_channels}x1 " \ - # f"--maxShapes=x:1x{out_channels}x{trt_max_len},mask:1x1x{trt_max_len},mu:1x{out_channels}x{trt_max_len},t:1,spks:1x{out_channels},cond:1x{out_channels}x{trt_max_len} " + \ - # ("--fp16" if args.export_half else "") - - # print("execute ", trtexec_cmd) - - # if args.exec_export: - # os.system(trtexec_cmd) - - # dummy_input = {'x': x, 'mask': mask, 'mu': mu, 't': t, 'spks': spks, 'cond': cond} - # test_calculate_value(estimator, onnx_file_path, trt_file_path, dummy_input, args) + for _ in tqdm(range(10)): + x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device) + output_pytorch = estimator(x, mask, mu, t, spks, cond) + ort_inputs = { + 'x': x.cpu().numpy(), + 'mask': mask.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy() + } + output_onnx = estimator_onnx.run(None, ort_inputs)[0] + torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) if __name__ == "__main__": main() diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index cf5e4e5..eab5cad 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging class CosyVoice: - def __init__(self, model_dir, load_jit=True, load_trt=False, load_onnx=True, use_fp16=False): + def __init__(self, model_dir, load_jit=True, load_onnx=True): instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir if not os.path.exists(model_dir): @@ -39,17 +39,12 @@ class CosyVoice: self.model.load('{}/llm.pt'.format(model_dir), '{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) - if load_jit: self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir), - '{}/llm.llm.fp16.zip'.format(model_dir)) - - # if load_trt: - # self.model.load_trt(model_dir, use_fp16) - + '{}/llm.llm.fp16.zip'.format(model_dir), + '{}/flow.encoder.fp32.zip'.format(model_dir)) if load_onnx: - self.model.load_onnx(model_dir, use_fp16) - + self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir)) del configs def list_avaliable_spks(self): diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index a5348d2..a78ded4 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import torch import numpy as np import threading @@ -20,7 +19,6 @@ from contextlib import nullcontext import uuid from cosyvoice.utils.common import fade_in_out import numpy as np -import onnxruntime as ort class CosyVoiceModel: @@ -62,47 +60,22 @@ class CosyVoiceModel: self.hift.load_state_dict(torch.load(hift_model, map_location=self.device)) self.hift.to(self.device).eval() - def load_jit(self, llm_text_encoder_model, llm_llm_model): + def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model): llm_text_encoder = torch.jit.load(llm_text_encoder_model) self.llm.text_encoder = llm_text_encoder llm_llm = torch.jit.load(llm_llm_model) self.llm.llm = llm_llm + flow_encoder = torch.jit.load(flow_encoder_model) + self.flow.encoder = flow_encoder - # def load_trt(self, model_dir, use_fp16): - # import tensorrt as trt - # trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan' - # trt_file_path = os.path.join(model_dir, trt_file_name) - # if not os.path.isfile(trt_file_path): - # raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file" - - # trt.init_libnvinfer_plugins(None, "") - # logger = trt.Logger(trt.Logger.WARNING) - # runtime = trt.Runtime(logger) - # with open(trt_file_path, 'rb') as f: - # serialized_engine = f.read() - # engine = runtime.deserialize_cuda_engine(serialized_engine) - # self.flow.decoder.estimator_context = engine.create_execution_context() - # self.flow.decoder.estimator = None - - def load_onnx(self, model_dir, use_fp16): - onnx_file_name = 'estimator_fp16.onnx' if use_fp16 else 'estimator_fp32.onnx' - onnx_file_path = os.path.join(model_dir, onnx_file_name) - if not os.path.isfile(onnx_file_path): - raise f"{onnx_file_path} does not exist. Please use bin/export_trt.py to generate .onnx file" - - providers = ['CUDAExecutionProvider'] - sess_options = ort.SessionOptions() - - # Add TensorRT Execution Provider - providers = [ - 'CUDAExecutionProvider' - ] - - # Load the ONNX model - self.flow.decoder.session = ort.InferenceSession(onnx_file_path, sess_options=sess_options, providers=providers) - # self.flow.decoder.estimator_context = None - self.flow.decoder.estimator = None - + def load_onnx(self, flow_decoder_estimator_model): + import onnxruntime + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + del self.flow.decoder.estimator + self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers) def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): with self.llm_context: @@ -207,4 +180,5 @@ class CosyVoiceModel: self.llm_end_dict.pop(this_uuid) self.mel_overlap_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 82e3196..e42facd 100755 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -31,8 +31,6 @@ class ConditionalCFM(BASECFM): in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) # Just change the architecture of the estimator here self.estimator = estimator - self.estimator_context = None # for tensorrt - self.session = None # for onnx @torch.inference_mode() def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): @@ -82,10 +80,10 @@ class ConditionalCFM(BASECFM): sol = [] for step in range(1, len(t_span)): - dphi_dt = self.estimator(x, mask, mu, t, spks, cond) + dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond) # Classifier-Free Guidance inference introduced in VoiceBox if self.inference_cfg_rate > 0: - cfg_dphi_dt = self.estimator( + cfg_dphi_dt = self.forward_estimator( x, mask, torch.zeros_like(mu), t, torch.zeros_like(spks) if spks is not None else None, @@ -102,51 +100,20 @@ class ConditionalCFM(BASECFM): return sol[-1] def forward_estimator(self, x, mask, mu, t, spks, cond): - - if self.estimator is not None: + if isinstance(self.estimator, torch.nn.Module): return self.estimator.forward(x, mask, mu, t, spks, cond) - # elif self.estimator_context is not None: - # assert self.training is False, 'tensorrt cannot be used in training' - # bs = x.shape[0] - # hs = x.shape[1] - # seq_len = x.shape[2] - # # assert bs == 1 and hs == 80 - # ret = torch.empty_like(x) - # self.estimator_context.set_input_shape("x", x.shape) - # self.estimator_context.set_input_shape("mask", mask.shape) - # self.estimator_context.set_input_shape("mu", mu.shape) - # self.estimator_context.set_input_shape("t", t.shape) - # self.estimator_context.set_input_shape("spks", spks.shape) - # self.estimator_context.set_input_shape("cond", cond.shape) - - # # Create a list of bindings - # bindings = [int(x.data_ptr()), int(mask.data_ptr()), int(mu.data_ptr()), int(t.data_ptr()), int(spks.data_ptr()), int(cond.data_ptr()), int(ret.data_ptr())] - - # # Execute the inference - # self.estimator_context.execute_v2(bindings=bindings) - # return ret else: - x_np = x.cpu().numpy() - mask_np = mask.cpu().numpy() - mu_np = mu.cpu().numpy() - t_np = t.cpu().numpy() - spks_np = spks.cpu().numpy() - cond_np = cond.cpu().numpy() - ort_inputs = { - 'x': x_np, - 'mask': mask_np, - 'mu': mu_np, - 't': t_np, - 'spks': spks_np, - 'cond': cond_np + 'x': x.cpu().numpy(), + 'mask': mask.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy() } - - output = self.session.run(None, ort_inputs)[0] - + output = self.estimator.run(None, ort_inputs)[0] return torch.tensor(output, dtype=x.dtype, device=x.device) - def compute_loss(self, x1, mask, mu, spks=None, cond=None): """Computes diffusion loss diff --git a/requirements.txt b/requirements.txt index c7a7f7d..9782ca3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ matplotlib==3.7.5 modelscope==1.15.0 networkx==3.1 omegaconf==2.3.0 +onnx==1.16.0 onnxruntime-gpu==1.16.0; sys_platform == 'linux' onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows' openai-whisper==20231117 From 7555afb90adfc21c95ca06667713ca8754efb3d1 Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Thu, 5 Sep 2024 14:09:40 +0800 Subject: [PATCH 09/11] update fastapi --- cosyvoice/cli/cosyvoice.py | 9 +- cosyvoice/hifigan/generator.py | 2 +- examples/libritts/cosyvoice/run.sh | 6 + examples/magicdata-read/cosyvoice/run.sh | 6 + runtime/python/fastapi/client.py | 92 +++++++------ runtime/python/fastapi/server.py | 164 +++++++++-------------- 6 files changed, 131 insertions(+), 148 deletions(-) diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index eab5cad..5e1ea9c 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -13,6 +13,7 @@ # limitations under the License. import os import time +from tqdm import tqdm from hyperpyyaml import load_hyperpyyaml from modelscope import snapshot_download from cosyvoice.cli.frontend import CosyVoiceFrontEnd @@ -52,7 +53,7 @@ class CosyVoice: return spks def inference_sft(self, tts_text, spk_id, stream=False): - for i in self.frontend.text_normalize(tts_text, split=True): + for i in tqdm(self.frontend.text_normalize(tts_text, split=True)): model_input = self.frontend.frontend_sft(i, spk_id) start_time = time.time() logging.info('synthesis text {}'.format(i)) @@ -64,7 +65,7 @@ class CosyVoice: def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False): prompt_text = self.frontend.text_normalize(prompt_text, split=False) - for i in self.frontend.text_normalize(tts_text, split=True): + for i in tqdm(self.frontend.text_normalize(tts_text, split=True)): model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k) start_time = time.time() logging.info('synthesis text {}'.format(i)) @@ -77,7 +78,7 @@ class CosyVoice: def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False): if self.frontend.instruct is True: raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir)) - for i in self.frontend.text_normalize(tts_text, split=True): + for i in tqdm(self.frontend.text_normalize(tts_text, split=True)): model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k) start_time = time.time() logging.info('synthesis text {}'.format(i)) @@ -91,7 +92,7 @@ class CosyVoice: if self.frontend.instruct is False: raise ValueError('{} do not support instruct inference'.format(self.model_dir)) instruct_text = self.frontend.text_normalize(instruct_text, split=False) - for i in self.frontend.text_normalize(tts_text, split=True): + for i in tqdm(self.frontend.text_normalize(tts_text, split=True)): model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text) start_time = time.time() logging.info('synthesis text {}'.format(i)) diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py index fd61834..b640219 100644 --- a/cosyvoice/hifigan/generator.py +++ b/cosyvoice/hifigan/generator.py @@ -340,7 +340,7 @@ class HiFTGenerator(nn.Module): s = self._f02source(f0) # use cache_source to avoid glitch - if cache_source.shape[2] == 0: + if cache_source.shape[2] != 0: s[:, :, :cache_source.shape[2]] = cache_source s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) diff --git a/examples/libritts/cosyvoice/run.sh b/examples/libritts/cosyvoice/run.sh index 96eca9b..386e9e4 100644 --- a/examples/libritts/cosyvoice/run.sh +++ b/examples/libritts/cosyvoice/run.sh @@ -102,4 +102,10 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then --deepspeed_config ./conf/ds_stage2.json \ --deepspeed.save_states model+optimizer done +fi + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir" + python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir + python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir fi \ No newline at end of file diff --git a/examples/magicdata-read/cosyvoice/run.sh b/examples/magicdata-read/cosyvoice/run.sh index 0cf6f6d..0a080ac 100644 --- a/examples/magicdata-read/cosyvoice/run.sh +++ b/examples/magicdata-read/cosyvoice/run.sh @@ -102,4 +102,10 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then --deepspeed_config ./conf/ds_stage2.json \ --deepspeed.save_states model+optimizer done +fi + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir" + python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir + python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir fi \ No newline at end of file diff --git a/runtime/python/fastapi/client.py b/runtime/python/fastapi/client.py index cf32092..981c7c1 100644 --- a/runtime/python/fastapi/client.py +++ b/runtime/python/fastapi/client.py @@ -1,56 +1,68 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import argparse import logging import requests +import torch +import torchaudio +import numpy as np -def saveResponse(path, response): - # 以二进制写入模式打开文件 - with open(path, 'wb') as file: - # 将响应的二进制内容写入文件 - file.write(response.content) def main(): - api = args.api_base + url = "http://{}:{}/inference_{}".format(args.host, args.port, args.mode) if args.mode == 'sft': - url = api + "/api/inference/sft" - payload={ - 'tts': args.tts_text, - 'role': args.spk_id - } - response = requests.request("POST", url, data=payload) - saveResponse(args.tts_wav, response) - elif args.mode == 'zero_shot': - url = api + "/api/inference/zero-shot" - payload={ - 'tts': args.tts_text, - 'prompt': args.prompt_text - } - files=[('audio', ('prompt_audio.wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))] - response = requests.request("POST", url, data=payload, files=files) - saveResponse(args.tts_wav, response) - elif args.mode == 'cross_lingual': - url = api + "/api/inference/cross-lingual" - payload={ - 'tts': args.tts_text, - } - files=[('audio', ('prompt_audio.wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))] - response = requests.request("POST", url, data=payload, files=files) - saveResponse(args.tts_wav, response) - else: - url = api + "/api/inference/instruct" payload = { - 'tts': args.tts_text, - 'role': args.spk_id, - 'instruct': args.instruct_text + 'tts_text': args.tts_text, + 'spk_id': args.spk_id } - response = requests.request("POST", url, data=payload) - saveResponse(args.tts_wav, response) - logging.info("Response save to {}", args.tts_wav) + response = requests.request("GET", url, data=payload, stream=True) + elif args.mode == 'zero_shot': + payload = { + 'tts_text': args.tts_text, + 'prompt_text': args.prompt_text + } + files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))] + response = requests.request("GET", url, data=payload, files=files, stream=True) + elif args.mode == 'cross_lingual': + payload = { + 'tts_text': args.tts_text, + } + files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))] + response = requests.request("GET", url, data=payload, files=files, stream=True) + else: + payload = { + 'tts_text': args.tts_text, + 'spk_id': args.spk_id, + 'instruct_text': args.instruct_text + } + response = requests.request("GET", url, data=payload, stream=True) + tts_audio = b'' + for r in response.iter_content(chunk_size=16000): + tts_audio += r + tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0) + logging.info('save response to {}'.format(args.tts_wav)) + torchaudio.save(args.tts_wav, tts_speech, target_sr) + logging.info('get response') if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--api_base', + parser.add_argument('--host', type=str, - default='http://127.0.0.1:6006') + default='0.0.0.0') + parser.add_argument('--port', + type=int, + default='50000') parser.add_argument('--mode', default='sft', choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'], diff --git a/runtime/python/fastapi/server.py b/runtime/python/fastapi/server.py index b670665..c540b47 100644 --- a/runtime/python/fastapi/server.py +++ b/runtime/python/fastapi/server.py @@ -1,119 +1,77 @@ -# Set inference model -# export MODEL_DIR=pretrained_models/CosyVoice-300M-Instruct -# For development -# fastapi dev --port 6006 fastapi_server.py -# For production deployment -# fastapi run --port 6006 fastapi_server.py - +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os import sys -import io,time -from fastapi import FastAPI, Response, File, UploadFile, Form -from fastapi.responses import HTMLResponse -from fastapi.middleware.cors import CORSMiddleware #引入 CORS中间件模块 -from contextlib import asynccontextmanager ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append('{}/../../..'.format(ROOT_DIR)) sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) -from cosyvoice.cli.cosyvoice import CosyVoice -from cosyvoice.utils.file_utils import load_wav -import numpy as np -import torch -import torchaudio +import argparse import logging logging.getLogger('matplotlib').setLevel(logging.WARNING) +from fastapi import FastAPI, UploadFile, Form, File +from fastapi.responses import StreamingResponse +from fastapi.middleware.cors import CORSMiddleware +import uvicorn +import numpy as np +from cosyvoice.cli.cosyvoice import CosyVoice +from cosyvoice.utils.file_utils import load_wav -class LaunchFailed(Exception): - pass - -@asynccontextmanager -async def lifespan(app: FastAPI): - model_dir = os.getenv("MODEL_DIR", "pretrained_models/CosyVoice-300M-SFT") - if model_dir: - logging.info("MODEL_DIR is {}", model_dir) - app.cosyvoice = CosyVoice(model_dir) - # sft usage - logging.info("Avaliable speakers {}", app.cosyvoice.list_avaliable_spks()) - else: - raise LaunchFailed("MODEL_DIR environment must set") - yield - -app = FastAPI(lifespan=lifespan) - -#设置允许访问的域名 -origins = ["*"] #"*",即为所有,也可以改为允许的特定ip。 +app = FastAPI() +# set cross region allowance app.add_middleware( - CORSMiddleware, - allow_origins=origins, #设置允许的origins来源 + CORSMiddleware, + allow_origins=["*"], allow_credentials=True, - allow_methods=["*"], # 设置允许跨域的http方法,比如 get、post、put等。 - allow_headers=["*"]) #允许跨域的headers,可以用来鉴别来源等作用。 + allow_methods=["*"], + allow_headers=["*"]) -def buildResponse(output): - buffer = io.BytesIO() - torchaudio.save(buffer, output, 22050, format="wav") - buffer.seek(0) - return Response(content=buffer.read(-1), media_type="audio/wav") +def generate_data(model_output): + for i in model_output: + tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes() + yield tts_audio -@app.post("/api/inference/sft") -@app.get("/api/inference/sft") -async def sft(tts: str = Form(), role: str = Form()): - start = time.process_time() - output = app.cosyvoice.inference_sft(tts, role) - end = time.process_time() - logging.info("infer time is {} seconds", end-start) - return buildResponse(output['tts_speech']) +@app.get("/inference_sft") +async def inference_sft(tts_text: str = Form(), spk_id: str = Form()): + model_output = cosyvoice.inference_sft(tts_text, spk_id) + return StreamingResponse(generate_data(model_output)) -@app.post("/api/inference/zero-shot") -async def zeroShot(tts: str = Form(), prompt: str = Form(), audio: UploadFile = File()): - start = time.process_time() - prompt_speech = load_wav(audio.file, 16000) - prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes() - prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0) - prompt_speech_16k = prompt_speech_16k.float() / (2**15) +@app.get("/inference_zero_shot") +async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()): + prompt_speech_16k = load_wav(prompt_wav.file, 16000) + model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k) + return StreamingResponse(generate_data(model_output)) - output = app.cosyvoice.inference_zero_shot(tts, prompt, prompt_speech_16k) - end = time.process_time() - logging.info("infer time is {} seconds", end-start) - return buildResponse(output['tts_speech']) +@app.get("/inference_cross_lingual") +async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()): + prompt_speech_16k = load_wav(prompt_wav.file, 16000) + model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k) + return StreamingResponse(generate_data(model_output)) -@app.post("/api/inference/cross-lingual") -async def crossLingual(tts: str = Form(), audio: UploadFile = File()): - start = time.process_time() - prompt_speech = load_wav(audio.file, 16000) - prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes() - prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0) - prompt_speech_16k = prompt_speech_16k.float() / (2**15) +@app.get("/inference_instruct") +async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()): + model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text) + return StreamingResponse(generate_data(model_output)) - output = app.cosyvoice.inference_cross_lingual(tts, prompt_speech_16k) - end = time.process_time() - logging.info("infer time is {} seconds", end-start) - return buildResponse(output['tts_speech']) - -@app.post("/api/inference/instruct") -@app.get("/api/inference/instruct") -async def instruct(tts: str = Form(), role: str = Form(), instruct: str = Form()): - start = time.process_time() - output = app.cosyvoice.inference_instruct(tts, role, instruct) - end = time.process_time() - logging.info("infer time is {} seconds", end-start) - return buildResponse(output['tts_speech']) - -@app.get("/api/roles") -async def roles(): - return {"roles": app.cosyvoice.list_avaliable_spks()} - -@app.get("/", response_class=HTMLResponse) -async def root(): - return """ - - - - - Api information - - - Get the supported tones from the Roles API first, then enter the tones and textual content in the TTS API for synthesis. Documents of API - - - """ +if __name__=='__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--port', + type=int, + default=50000) + parser.add_argument('--model_dir', + type=str, + default='iic/CosyVoice-300M', + help='local path or modelscope repo id') + args = parser.parse_args() + cosyvoice = CosyVoice(args.model_dir) + uvicorn.run(app, host="127.0.0.1", port=args.port) \ No newline at end of file From 11eacb810e27af2047f1b7fe9a37ad22d6180336 Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Thu, 5 Sep 2024 14:12:27 +0800 Subject: [PATCH 10/11] add uvicorn requirements --- README.md | 2 +- requirements.txt | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4d1a7f7..b9d0a6e 100644 --- a/README.md +++ b/README.md @@ -166,7 +166,7 @@ docker build -t cosyvoice:v1.0 . docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity" cd grpc && python3 client.py --port 50000 --mode # for fastapi usage -docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && MODEL_DIR=iic/CosyVoice-300M fastapi dev --port 50000 server.py && sleep infinity" +docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity" cd fastapi && python3 client.py --port 50000 --mode ``` diff --git a/requirements.txt b/requirements.txt index 9782ca3..4189c5f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,6 +26,7 @@ soundfile==0.12.1 tensorboard==2.14.0 torch==2.0.1 torchaudio==2.0.2 +uvicorn==0.30.0 wget==3.2 fastapi==0.111.0 fastapi-cli==0.0.4 From e141634da18796717417a954936b454c32640d22 Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Thu, 5 Sep 2024 14:26:12 +0800 Subject: [PATCH 11/11] remove unnecessary code --- cosyvoice/flow/flow.py | 2 +- cosyvoice/flow/flow_matching.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index 10a0bf3..2d0a730 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -113,7 +113,7 @@ class MaskedDiffWithXvec(torch.nn.Module): # concat text and prompt_text token_len1, token_len2 = prompt_token.shape[1], token.shape[1] token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len - mask = (~make_pad_mask(token_len)).to(embedding.dtype).unsqueeze(-1).to(embedding) + mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) token = self.input_embedding(torch.clamp(token, min=0)) * mask # text encode diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index e42facd..7e31177 100755 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -14,8 +14,6 @@ import torch import torch.nn.functional as F from matcha.models.components.flow_matching import BASECFM -import onnxruntime as ort -import numpy as np class ConditionalCFM(BASECFM): def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):