add online trt export

lyuxiang.lx
2025-01-10 13:55:05 +08:00
parent 426c4001ca
commit 1cfc5dd077
13 changed files with 100 additions and 167 deletions

View File

@@ -128,8 +128,6 @@ import torchaudio
 **CosyVoice2 Usage**
 ```python
-# NOTE if you want to use tensorRT to accerlate the flow matching inference, please set load_trt=True.
-# if you don't want to save tensorRT model on disk, please set environment variable `NOT_SAVE_TRT=1`.
 cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
 # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference

View File

@@ -53,7 +53,9 @@ class CosyVoice:
                                 '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
-            self.model.load_trt('{}/flow.decoder.estimator.{}.v100.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                self.fp16)
         del configs

     def list_available_spks(self):
@@ -149,7 +151,9 @@ class CosyVoice2(CosyVoice):
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
-            self.model.load_trt('{}/flow.decoder.estimator'.format(model_dir), self.fp16)
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                self.fp16)
         del configs

     def inference_instruct(self, *args, **kwargs):
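Both constructors now hand load_trt the expected plan path plus the fp32 ONNX export, so the engine can be built online for whatever GPU is present (hence the mygpu suffix replacing the hard-coded v100). A minimal usage sketch, assuming the pretrained model directory ships flow.decoder.estimator.fp32.onnx:

```python
from cosyvoice.cli.cosyvoice import CosyVoice2

# The first run with load_trt=True builds and caches
# pretrained_models/CosyVoice2-0.5B/flow.decoder.estimator.fp16.mygpu.plan;
# later runs find the plan on disk and skip the (slow) TensorRT build.
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B',
                       load_jit=False, load_trt=True, fp16=True)
```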

View File

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import torch
 import numpy as np
 import threading
@@ -19,7 +20,7 @@ from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
-from cosyvoice.trt.estimator_trt import EstimatorTRT
+from cosyvoice.utils.file_utils import convert_onnx_to_trt


 class CosyVoiceModel:
@@ -36,6 +37,9 @@ class CosyVoiceModel:
         self.fp16 = fp16
         self.llm.fp16 = fp16
         self.flow.fp16 = fp16
+        if self.fp16 is True:
+            self.llm.half()
+            self.flow.half()
         self.token_min_hop_len = 2 * self.flow.input_frame_rate
         self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
@@ -70,9 +74,6 @@ class CosyVoiceModel:
         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
         self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()
-        if self.fp16 is True:
-            self.llm.half()
-            self.flow.half()

     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
         llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
@@ -82,9 +83,17 @@ class CosyVoiceModel:
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

-    def load_trt(self, flow_decoder_estimator_model, fp16):
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
+        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
+        if not os.path.exists(flow_decoder_estimator_model):
+            convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
         del self.flow.decoder.estimator
-        self.flow.decoder.estimator = EstimatorTRT(flow_decoder_estimator_model, self.device, fp16)
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        if self.flow.decoder.estimator_engine is None:
+            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
+        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()

     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
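The rewritten load_trt converts the ONNX export on first use, then deserializes the cached plan; the engine object is kept alive on flow.decoder while its execution context replaces the torch estimator. A standalone sketch of that load path (plan file name assumed), using only TensorRT Python API calls the diff itself relies on:

```python
import tensorrt as trt

plan_path = 'flow.decoder.estimator.fp16.mygpu.plan'  # assumed path
logger = trt.Logger(trt.Logger.INFO)
with open(plan_path, 'rb') as f:
    engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
if engine is None:  # deserialize_cuda_engine returns None on failure
    raise ValueError('failed to load trt {}'.format(plan_path))
context = engine.create_execution_context()  # what flow.decoder.estimator becomes
```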
@@ -269,6 +278,9 @@ class CosyVoice2Model(CosyVoiceModel):
         self.fp16 = fp16
         self.llm.fp16 = fp16
         self.flow.fp16 = fp16
+        if self.fp16 is True:
+            self.llm.half()
+            self.flow.half()
         self.token_hop_len = 2 * self.flow.input_frame_rate
         # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
         self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate

View File

@@ -21,7 +21,6 @@ import torchaudio
 from torch.nn.utils.rnn import pad_sequence
 import torch.nn.functional as F

-torchaudio.set_audio_backend('soundfile')

 AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
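Dropping this call is forward-compatible: torchaudio.set_audio_backend has been a deprecated no-op since the 2.x dispatcher landed, and newer torchaudio releases remove it outright, which would otherwise break this module at import time.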

View File

@@ -134,12 +134,12 @@ class ConditionalCFM(BASECFM):
         self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
         # run trt engine
         self.estimator.execute_v2([x.contiguous().data_ptr(),
                                    mask.contiguous().data_ptr(),
                                    mu.contiguous().data_ptr(),
                                    t.contiguous().data_ptr(),
                                    spks.contiguous().data_ptr(),
                                    cond.contiguous().data_ptr(),
                                    x.data_ptr()])
         return x

     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
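For context, execute_v2 takes a flat list of device pointers in engine binding order (the six inputs, then the output); reusing x.data_ptr() as the last entry makes the engine write the estimate back into x's storage. A hedged sketch restating the call pattern on dummy tensors, assuming `context` is the execution context created in load_trt:

```python
import torch

def run_estimator(context, x, mask, mu, t, spks, cond):
    # every tensor must already live on the GPU and be contiguous
    T = x.size(2)
    for name, shape in [('x', (2, 80, T)), ('mask', (2, 1, T)), ('mu', (2, 80, T)),
                        ('t', (2,)), ('spks', (2, 80)), ('cond', (2, 80, T))]:
        context.set_input_shape(name, shape)  # dynamic axes need explicit shapes
    # binding order: six inputs, then the output; x doubles as the output buffer
    context.execute_v2([x.contiguous().data_ptr(), mask.contiguous().data_ptr(),
                        mu.contiguous().data_ptr(), t.contiguous().data_ptr(),
                        spks.contiguous().data_ptr(), cond.contiguous().data_ptr(),
                        x.data_ptr()])
    return x
```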

View File

@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
 from typing import List, Optional, Tuple
 from einops import rearrange
 from torchaudio.transforms import Spectrogram

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 import torch
 import torch.nn as nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm


 class ConvRNNF0Predictor(nn.Module):

View File

@@ -23,7 +23,7 @@ import torch.nn.functional as F
 from torch.nn import Conv1d
 from torch.nn import ConvTranspose1d
 from torch.nn.utils import remove_weight_norm
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
 from torch.distributions.uniform import Uniform
 from cosyvoice.transformer.activation import Snake
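All three files swap the deprecated torch.nn.utils.weight_norm for the parametrization-based variant that newer PyTorch releases recommend; the call signature is unchanged, only the import moves. A minimal sketch:

```python
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm

# same call as the deprecated torch.nn.utils.weight_norm, but the norm/direction
# split is now stored as a parametrization on conv.weight
conv = weight_norm(nn.Conv1d(80, 80, kernel_size=3, padding=1), name='weight')
```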

View File

@@ -1,141 +0,0 @@
import os
import torch
import tensorrt as trt
import logging
import threading
_min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)]
_opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)]
_max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)]
class EstimatorTRT:
def __init__(self, path_prefix: str, device: torch.device, fp16: bool = True):
self.lock = threading.Lock()
self.device = device
with torch.cuda.device(device):
self.input_names = ["x", "mask", "mu", "t", "spks", "cond"]
self.output_name = "estimator_out"
onnx_path = path_prefix + ".fp32.onnx"
precision = ".fp16" if fp16 else ".fp32"
trt_path = path_prefix + precision +".plan"
self.fp16 = fp16
self.logger = trt.Logger(trt.Logger.INFO)
self.trt_runtime = trt.Runtime(self.logger)
save_trt = not os.environ.get("NOT_SAVE_TRT", "0") == "1"
if os.path.exists(trt_path):
self.engine = self._load_trt(trt_path)
else:
self.engine = self._convert_onnx_to_trt(onnx_path, trt_path, save_trt)
self.context = self.engine.create_execution_context()
def _convert_onnx_to_trt(
self, onnx_path: str, trt_path: str, save_trt: bool = True
):
logging.info("Converting onnx to trt...")
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
builder = trt.Builder(self.logger)
network = builder.create_network(network_flags)
parser = trt.OnnxParser(network, self.logger)
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33) # 8GB
if (self.fp16):
config.set_flag(trt.BuilderFlag.FP16)
profile = builder.create_optimization_profile()
# load onnx model
with open(onnx_path, "rb") as f:
if not parser.parse(f.read()):
for error in range(parser.num_errors):
print(parser.get_error(error))
exit(1)
# set input shapes
for i in range(len(self.input_names)):
profile.set_shape(
self.input_names[i], _min_shape[i], _opt_shape[i], _max_shape[i]
)
tensor_dtype = trt.DataType.HALF if self.fp16 else trt.DataType.FLOAT
# set input and output data type
for i in range(network.num_inputs):
input_tensor = network.get_input(i)
input_tensor.dtype = tensor_dtype
for i in range(network.num_outputs):
output_tensor = network.get_output(i)
output_tensor.dtype = tensor_dtype
config.add_optimization_profile(profile)
engine_bytes = builder.build_serialized_network(network, config)
# save trt engine
if save_trt:
with open(trt_path, "wb") as f:
f.write(engine_bytes)
print("trt engine saved to {}".format(trt_path))
engine = self.trt_runtime.deserialize_cuda_engine(engine_bytes)
return engine
def _load_trt(self, trt_path: str):
logging.info("Found trt engine, loading...")
with open(trt_path, "rb") as f:
engine_bytes = f.read()
engine = self.trt_runtime.deserialize_cuda_engine(engine_bytes)
return engine
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
mu: torch.Tensor,
t: torch.Tensor,
spks: torch.Tensor,
cond: torch.Tensor,
):
with self.lock:
with torch.cuda.device(self.device):
self.context.set_input_shape("x", (2, 80, x.size(2)))
self.context.set_input_shape("mask", (2, 1, x.size(2)))
self.context.set_input_shape("mu", (2, 80, x.size(2)))
self.context.set_input_shape("t", (2,))
self.context.set_input_shape("spks", (2, 80))
self.context.set_input_shape("cond", (2, 80, x.size(2)))
# run trt engine
self.context.execute_v2(
[
x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr(),
]
)
return x
def __call__(
self,
x: torch.Tensor,
mask: torch.Tensor,
mu: torch.Tensor,
t: torch.Tensor,
spks: torch.Tensor,
cond: torch.Tensor,
):
return self.forward(x, mask, mu, t, spks, cond)
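With this file deleted, its two jobs move elsewhere: offline engine building becomes cosyvoice.utils.file_utils.convert_onnx_to_trt (added below), and runtime execution now happens directly through the deserialized engine's execution context attached to flow.decoder, dropping the per-call lock and the NOT_SAVE_TRT escape hatch along the way.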

View File

@@ -1,5 +1,5 @@
 # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
-#               2024 Alibaba Inc (authors: Xiang Lyu)
+#               2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
 # limitations under the License.
 import json
+import tensorrt as trt
 import torchaudio
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
@@ -45,3 +46,44 @@ def load_wav(wav, target_sr):
     assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
     speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
     return speech
+
+
+def convert_onnx_to_trt(trt_model, onnx_model, fp16):
+    _min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)]
+    _opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)]
+    _max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)]
+    input_names = ["x", "mask", "mu", "t", "spks", "cond"]
+    logging.info("Converting onnx to trt...")
+    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    logger = trt.Logger(trt.Logger.INFO)
+    builder = trt.Builder(logger)
+    network = builder.create_network(network_flags)
+    parser = trt.OnnxParser(network, logger)
+    config = builder.create_builder_config()
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33)  # 8GB
+    if fp16:
+        config.set_flag(trt.BuilderFlag.FP16)
+    profile = builder.create_optimization_profile()
+    # load onnx model
+    with open(onnx_model, "rb") as f:
+        if not parser.parse(f.read()):
+            for error in range(parser.num_errors):
+                print(parser.get_error(error))
+            raise ValueError('failed to parse {}'.format(onnx_model))
+    # set input shapes
+    for i in range(len(input_names)):
+        profile.set_shape(input_names[i], _min_shape[i], _opt_shape[i], _max_shape[i])
+    tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
+    # set input and output data type
+    for i in range(network.num_inputs):
+        input_tensor = network.get_input(i)
+        input_tensor.dtype = tensor_dtype
+    for i in range(network.num_outputs):
+        output_tensor = network.get_output(i)
+        output_tensor.dtype = tensor_dtype
+    config.add_optimization_profile(profile)
+    engine_bytes = builder.build_serialized_network(network, config)
+    # save trt engine
+    with open(trt_model, "wb") as f:
+        f.write(engine_bytes)
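A minimal offline usage sketch for the new helper (paths assumed); the first positional argument is the plan file to write, the second the ONNX export to read:

```python
from cosyvoice.utils.file_utils import convert_onnx_to_trt

# builds the fp16 plan once, ahead of time, on the target GPU
convert_onnx_to_trt('pretrained_models/CosyVoice2-0.5B/flow.decoder.estimator.fp16.mygpu.plan',
                    'pretrained_models/CosyVoice2-0.5B/flow.decoder.estimator.fp32.onnx',
                    fp16=True)
```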

View File

@@ -24,7 +24,7 @@ import numpy as np
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../../..'.format(ROOT_DIR))
 sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 from cosyvoice.utils.file_utils import load_wav

 app = FastAPI()
@@ -79,5 +79,11 @@ if __name__ == '__main__':
                         default='iic/CosyVoice-300M',
                         help='local path or modelscope repo id')
     args = parser.parse_args()
-    cosyvoice = CosyVoice(args.model_dir)
+    try:
+        cosyvoice = CosyVoice(args.model_dir)
+    except Exception:
+        try:
+            cosyvoice = CosyVoice2(args.model_dir)
+        except Exception:
+            raise TypeError('no valid model_type!')
     uvicorn.run(app, host="0.0.0.0", port=args.port)
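The same try/except cascade now appears in the fastapi server, the grpc server, and the webui: instantiation failure doubles as model-type detection. A hedged refactoring sketch (helper name hypothetical) showing the shared logic in one place:

```python
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2

def load_cosyvoice(model_dir):
    # hypothetical helper: try each loader in turn, as the three entry points do inline
    for cls in (CosyVoice, CosyVoice2):
        try:
            return cls(model_dir)
        except Exception:
            continue
    raise TypeError('no valid model_type!')
```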

View File

@@ -25,7 +25,7 @@ import numpy as np
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../../..'.format(ROOT_DIR))
 sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2

 logging.basicConfig(level=logging.DEBUG,
                     format='%(asctime)s %(levelname)s %(message)s')
@@ -33,7 +33,13 @@ logging.basicConfig(level=logging.DEBUG,

 class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
     def __init__(self, args):
-        self.cosyvoice = CosyVoice(args.model_dir)
+        try:
+            self.cosyvoice = CosyVoice(args.model_dir)
+        except Exception:
+            try:
+                self.cosyvoice = CosyVoice2(args.model_dir)
+            except Exception:
+                raise TypeError('no valid model_type!')
         logging.info('grpc service initialized')

     def Inference(self, request, context):

View File

@@ -184,7 +184,14 @@ if __name__ == '__main__':
                         default='pretrained_models/CosyVoice2-0.5B',
                         help='local path or modelscope repo id')
     args = parser.parse_args()
-    cosyvoice = CosyVoice2(args.model_dir) if 'CosyVoice2' in args.model_dir else CosyVoice(args.model_dir)
+    try:
+        cosyvoice = CosyVoice(args.model_dir)
+    except Exception:
+        try:
+            cosyvoice = CosyVoice2(args.model_dir)
+        except Exception:
+            raise TypeError('no valid model_type!')
     sft_spk = cosyvoice.list_available_spks()

     prompt_sr = 16000
     default_data = np.zeros(cosyvoice.sample_rate)