diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 68379a4..2da3d0a 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -53,7 +53,9 @@ class CosyVoice: '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) if load_trt: - self.model.load_trt('{}/flow.decoder.estimator.{}.v100.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) + self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), + '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), + self.fp16) del configs def list_available_spks(self): @@ -149,7 +151,9 @@ class CosyVoice2(CosyVoice): if load_jit: self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) if load_trt: - self.model.load_trt('{}/flow.decoder.estimator.{}.v100.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) + self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), + '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), + self.fp16) del configs def inference_instruct(self, *args, **kwargs): diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 0840e46..20de439 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import torch import numpy as np import threading @@ -19,6 +20,7 @@ from torch.nn import functional as F from contextlib import nullcontext import uuid from cosyvoice.utils.common import fade_in_out +from cosyvoice.utils.file_utils import convert_onnx_to_trt class CosyVoiceModel: @@ -35,6 +37,9 @@ class CosyVoiceModel: self.fp16 = fp16 self.llm.fp16 = fp16 self.flow.fp16 = fp16 + if self.fp16 is True: + self.llm.half() + self.flow.half() self.token_min_hop_len = 2 * self.flow.input_frame_rate self.token_max_hop_len = 4 * self.flow.input_frame_rate self.token_overlap_len = 20 @@ -69,9 +74,6 @@ class CosyVoiceModel: hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()} self.hift.load_state_dict(hift_state_dict, strict=True) self.hift.to(self.device).eval() - if self.fp16 is True: - self.llm.half() - self.flow.half() def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model): llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device) @@ -81,7 +83,10 @@ class CosyVoiceModel: flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder - def load_trt(self, flow_decoder_estimator_model): + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16): + assert torch.cuda.is_available(), 'tensorrt only supports gpu!' + if not os.path.exists(flow_decoder_estimator_model): + convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16) del self.flow.decoder.estimator import tensorrt as trt with open(flow_decoder_estimator_model, 'rb') as f: @@ -204,6 +209,7 @@ class CosyVoiceModel: self.mel_overlap_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) self.flow_cache_dict.pop(this_uuid) + torch.cuda.empty_cache() def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread @@ -257,6 +263,7 @@ class CosyVoiceModel: self.llm_end_dict.pop(this_uuid) self.mel_overlap_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) + torch.cuda.empty_cache() class CosyVoice2Model(CosyVoiceModel): @@ -273,6 +280,9 @@ class CosyVoice2Model(CosyVoiceModel): self.fp16 = fp16 self.llm.fp16 = fp16 self.flow.fp16 = fp16 + if self.fp16 is True: + self.llm.half() + self.flow.half() self.token_hop_len = 2 * self.flow.input_frame_rate # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate @@ -385,3 +395,4 @@ class CosyVoice2Model(CosyVoiceModel): with self.lock: self.tts_speech_token_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid) + torch.cuda.empty_cache() diff --git a/cosyvoice/dataset/processor.py b/cosyvoice/dataset/processor.py index e0d3979..3c0b176 100644 --- a/cosyvoice/dataset/processor.py +++ b/cosyvoice/dataset/processor.py @@ -21,7 +21,6 @@ import torchaudio from torch.nn.utils.rnn import pad_sequence import torch.nn.functional as F -torchaudio.set_audio_backend('soundfile') AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'} diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index f4e0ace..6a60f6d 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -134,12 +134,12 @@ class ConditionalCFM(BASECFM): self.estimator.set_input_shape('cond', (2, 80, x.size(2))) # run trt engine self.estimator.execute_v2([x.contiguous().data_ptr(), - mask.contiguous().data_ptr(), - mu.contiguous().data_ptr(), - t.contiguous().data_ptr(), - spks.contiguous().data_ptr(), - cond.contiguous().data_ptr(), - x.data_ptr()]) + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()]) return x def compute_loss(self, x1, mask, mu, spks=None, cond=None): diff --git a/cosyvoice/hifigan/discriminator.py b/cosyvoice/hifigan/discriminator.py index 6fc7845..1a4dcc8 100644 --- a/cosyvoice/hifigan/discriminator.py +++ b/cosyvoice/hifigan/discriminator.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from torch.nn.utils import weight_norm +from torch.nn.utils.parametrizations import weight_norm from typing import List, Optional, Tuple from einops import rearrange from torchaudio.transforms import Spectrogram diff --git a/cosyvoice/hifigan/f0_predictor.py b/cosyvoice/hifigan/f0_predictor.py index 36b85f4..172c5f5 100644 --- a/cosyvoice/hifigan/f0_predictor.py +++ b/cosyvoice/hifigan/f0_predictor.py @@ -13,7 +13,7 @@ # limitations under the License. import torch import torch.nn as nn -from torch.nn.utils import weight_norm +from torch.nn.utils.parametrizations import weight_norm class ConvRNNF0Predictor(nn.Module): diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py index 0098b90..c47bf05 100644 --- a/cosyvoice/hifigan/generator.py +++ b/cosyvoice/hifigan/generator.py @@ -23,7 +23,7 @@ import torch.nn.functional as F from torch.nn import Conv1d from torch.nn import ConvTranspose1d from torch.nn.utils import remove_weight_norm -from torch.nn.utils import weight_norm +from torch.nn.utils.parametrizations import weight_norm from torch.distributions.uniform import Uniform from cosyvoice.transformer.activation import Snake diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py index 7e81e3d..3131769 100644 --- a/cosyvoice/utils/file_utils.py +++ b/cosyvoice/utils/file_utils.py @@ -1,5 +1,5 @@ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) -# 2024 Alibaba Inc (authors: Xiang Lyu) +# 2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ # limitations under the License. import json +import tensorrt as trt import torchaudio import logging logging.getLogger('matplotlib').setLevel(logging.WARNING) @@ -45,3 +46,44 @@ def load_wav(wav, target_sr): assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr) speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech) return speech + + +def convert_onnx_to_trt(trt_model, onnx_model, fp16): + _min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)] + _opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)] + _max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)] + input_names = ["x", "mask", "mu", "t", "spks", "cond"] + + logging.info("Converting onnx to trt...") + network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + logger = trt.Logger(trt.Logger.INFO) + builder = trt.Builder(logger) + network = builder.create_network(network_flags) + parser = trt.OnnxParser(network, logger) + config = builder.create_builder_config() + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33) # 8GB + if fp16: + config.set_flag(trt.BuilderFlag.FP16) + profile = builder.create_optimization_profile() + # load onnx model + with open(onnx_model, "rb") as f: + if not parser.parse(f.read()): + for error in range(parser.num_errors): + print(parser.get_error(error)) + raise ValueError('failed to parse {}'.format(onnx_model)) + # set input shapes + for i in range(len(input_names)): + profile.set_shape(input_names[i], _min_shape[i], _opt_shape[i], _max_shape[i]) + tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT + # set input and output data type + for i in range(network.num_inputs): + input_tensor = network.get_input(i) + input_tensor.dtype = tensor_dtype + for i in range(network.num_outputs): + output_tensor = network.get_output(i) + output_tensor.dtype = tensor_dtype + config.add_optimization_profile(profile) + engine_bytes = builder.build_serialized_network(network, config) + # save trt engine + with open(trt_model, "wb") as f: + f.write(engine_bytes) diff --git a/runtime/python/fastapi/server.py b/runtime/python/fastapi/server.py index bfe4a56..17aed2f 100644 --- a/runtime/python/fastapi/server.py +++ b/runtime/python/fastapi/server.py @@ -24,7 +24,7 @@ import numpy as np ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append('{}/../../..'.format(ROOT_DIR)) sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) -from cosyvoice.cli.cosyvoice import CosyVoice +from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 from cosyvoice.utils.file_utils import load_wav app = FastAPI() @@ -79,5 +79,11 @@ if __name__ == '__main__': default='iic/CosyVoice-300M', help='local path or modelscope repo id') args = parser.parse_args() - cosyvoice = CosyVoice(args.model_dir) + try: + cosyvoice = CosyVoice(args.model_dir) + except Exception: + try: + cosyvoice = CosyVoice2(args.model_dir) + except Exception: + raise TypeError('no valid model_type!') uvicorn.run(app, host="0.0.0.0", port=args.port) diff --git a/runtime/python/grpc/server.py b/runtime/python/grpc/server.py index 325fadc..1cb48ae 100644 --- a/runtime/python/grpc/server.py +++ b/runtime/python/grpc/server.py @@ -25,7 +25,7 @@ import numpy as np ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append('{}/../../..'.format(ROOT_DIR)) sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) -from cosyvoice.cli.cosyvoice import CosyVoice +from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s') @@ -33,7 +33,13 @@ logging.basicConfig(level=logging.DEBUG, class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer): def __init__(self, args): - self.cosyvoice = CosyVoice(args.model_dir) + try: + self.cosyvoice = CosyVoice(args.model_dir) + except Exception: + try: + self.cosyvoice = CosyVoice2(args.model_dir) + except Exception: + raise TypeError('no valid model_type!') logging.info('grpc service initialized') def Inference(self, request, context): diff --git a/webui.py b/webui.py index 6c310d5..e437414 100644 --- a/webui.py +++ b/webui.py @@ -184,7 +184,14 @@ if __name__ == '__main__': default='pretrained_models/CosyVoice2-0.5B', help='local path or modelscope repo id') args = parser.parse_args() - cosyvoice = CosyVoice2(args.model_dir) if 'CosyVoice2' in args.model_dir else CosyVoice(args.model_dir) + try: + cosyvoice = CosyVoice(args.model_dir) + except Exception: + try: + cosyvoice = CosyVoice2(args.model_dir) + except Exception: + raise TypeError('no valid model_type!') + sft_spk = cosyvoice.list_available_spks() prompt_sr = 16000 default_data = np.zeros(cosyvoice.sample_rate)