Merge pull request #865 from FunAudioLLM/dev/lyuxiang.lx

Dev/lyuxiang.lx
Author: Xiang Lyu
Date: 2025-01-10 14:18:21 +08:00
Committed by: GitHub
11 changed files with 97 additions and 22 deletions

View File

@@ -53,7 +53,9 @@ class CosyVoice:
                                 '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
-            self.model.load_trt('{}/flow.decoder.estimator.{}.v100.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                self.fp16)
         del configs

     def list_available_spks(self):

@@ -149,7 +151,9 @@ class CosyVoice2(CosyVoice):
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
-            self.model.load_trt('{}/flow.decoder.estimator.{}.v100.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                self.fp16)
         del configs

     def inference_instruct(self, *args, **kwargs):
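
In both wrappers the TensorRT engine path is renamed from ...v100.plan to ...mygpu.plan, and load_trt() now also receives the estimator's ONNX export and the fp16 flag so it can build that plan for whatever GPU is present (see the model.py hunks below). A minimal usage sketch, assuming the flags visible in this diff (load_jit, load_trt, fp16) and an illustrative model directory:

    # Sketch: enabling the TensorRT flow-decoder estimator.  On first use the
    # fp16 .mygpu.plan file is generated from flow.decoder.estimator.fp32.onnx
    # inside the model directory.
    from cosyvoice.cli.cosyvoice import CosyVoice

    cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M', load_trt=True, fp16=True)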

View File

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import torch
 import numpy as np
 import threading
@@ -19,6 +20,7 @@ from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
+from cosyvoice.utils.file_utils import convert_onnx_to_trt


 class CosyVoiceModel:
@@ -35,6 +37,9 @@ class CosyVoiceModel:
         self.fp16 = fp16
         self.llm.fp16 = fp16
         self.flow.fp16 = fp16
+        if self.fp16 is True:
+            self.llm.half()
+            self.flow.half()
         self.token_min_hop_len = 2 * self.flow.input_frame_rate
         self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
@@ -69,9 +74,6 @@ class CosyVoiceModel:
         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
         self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()
-        if self.fp16 is True:
-            self.llm.half()
-            self.flow.half()

     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
         llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
@@ -81,7 +83,10 @@ class CosyVoiceModel:
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

-    def load_trt(self, flow_decoder_estimator_model):
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
+        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
+        if not os.path.exists(flow_decoder_estimator_model):
+            convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
         del self.flow.decoder.estimator
         import tensorrt as trt
         with open(flow_decoder_estimator_model, 'rb') as f:
@@ -204,6 +209,7 @@ class CosyVoiceModel:
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
+        torch.cuda.empty_cache()

     def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
@@ -257,6 +263,7 @@ class CosyVoiceModel:
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
+        torch.cuda.empty_cache()


 class CosyVoice2Model(CosyVoiceModel):
@@ -273,6 +280,9 @@ class CosyVoice2Model(CosyVoiceModel):
         self.fp16 = fp16
         self.llm.fp16 = fp16
         self.flow.fp16 = fp16
+        if self.fp16 is True:
+            self.llm.half()
+            self.flow.half()
         self.token_hop_len = 2 * self.flow.input_frame_rate
         # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
         self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
@@ -385,3 +395,4 @@ class CosyVoice2Model(CosyVoiceModel):
             with self.lock:
                 self.tts_speech_token_dict.pop(this_uuid)
                 self.llm_end_dict.pop(this_uuid)
+        torch.cuda.empty_cache()
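
With this change the TensorRT plan is built lazily: load_trt() now asserts a CUDA device, converts the bundled ONNX estimator into a GPU-specific .mygpu.plan the first time the file is missing, and only then replaces flow.decoder.estimator with the deserialized engine. The deserialization step itself is not touched by the diff; a minimal sketch of what it amounts to, assuming the standard TensorRT Python API (the helper name below is illustrative, not the repository's exact code):

    # Sketch (assumption): roughly what happens once the plan file exists --
    # deserialize the engine and keep an execution context that
    # ConditionalCFM later drives through execute_v2().
    import tensorrt as trt

    def load_estimator_context(plan_path):
        logger = trt.Logger(trt.Logger.INFO)
        with open(plan_path, 'rb') as f:
            engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
        return engine.create_execution_context()

The torch.cuda.empty_cache() calls added at the end of tts() and vc() return cached, unused CUDA blocks to the driver once a generation finishes, so long-running servers do not keep holding peak-allocation memory between requests.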

View File

@@ -21,7 +21,6 @@ import torchaudio
 from torch.nn.utils.rnn import pad_sequence
 import torch.nn.functional as F

-torchaudio.set_audio_backend('soundfile')

 AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
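
The global torchaudio.set_audio_backend() call is dropped because that backend-dispatch API is deprecated in torchaudio 2.x. If the soundfile backend is genuinely needed, recent torchaudio versions accept it per call; a hedged sketch (the file path is illustrative):

    # Sketch, assuming torchaudio >= 2.1: request the backend per call
    # instead of setting it globally.
    import torchaudio

    speech, sample_rate = torchaudio.load('prompt.wav', backend='soundfile')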

View File

@@ -134,12 +134,12 @@ class ConditionalCFM(BASECFM):
             self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
             # run trt engine
             self.estimator.execute_v2([x.contiguous().data_ptr(),
                                        mask.contiguous().data_ptr(),
                                        mu.contiguous().data_ptr(),
                                        t.contiguous().data_ptr(),
                                        spks.contiguous().data_ptr(),
                                        cond.contiguous().data_ptr(),
                                        x.data_ptr()])
         return x

     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
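
This hunk only re-indents the execute_v2() call; the underlying pattern is that TensorRT receives a flat list of raw device pointers in the engine's binding order (the six inputs x, mask, mu, t, spks, cond, followed by the output buffer), and the result is written in place into x, which is why the method simply returns x afterwards. A small sketch of the shape contract, assuming the optimization profile added later in this PR (batch fixed at 2, mel dimension 80, free time axis):

    # Sketch (assumption): tensors bound to the TRT estimator must be
    # contiguous CUDA tensors; x doubles as the output buffer.
    import torch

    T = 500  # number of mel frames, anywhere in the profile's 4..6800 range
    x    = torch.randn(2, 80, T, device='cuda')   # noise, overwritten with the estimate
    mask = torch.ones(2, 1, T, device='cuda')
    mu   = torch.randn(2, 80, T, device='cuda')
    t    = torch.rand(2, device='cuda')
    spks = torch.randn(2, 80, device='cuda')
    cond = torch.randn(2, 80, T, device='cuda')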

View File

@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
 from typing import List, Optional, Tuple
 from einops import rearrange
 from torchaudio.transforms import Spectrogram
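
These import changes (here and in the two files below) switch from torch.nn.utils.weight_norm, which newer PyTorch releases mark as deprecated, to the parametrization-based torch.nn.utils.parametrizations.weight_norm. At construction time it is a drop-in replacement; a minimal sketch with illustrative layer sizes:

    # Sketch: wrapping a layer with the parametrization-based weight_norm.
    import torch.nn as nn
    from torch.nn.utils.parametrizations import weight_norm

    conv = weight_norm(nn.Conv1d(80, 256, kernel_size=3, padding=1))
    # the magnitude/direction tensors now live under conv.parametrizations.weight
    print(conv.parametrizations.weight.original0.shape)

One practical difference worth knowing: the parametrized form stores its tensors under parametrizations.weight.original0/original1 instead of weight_g/weight_v, so checkpoints saved with one API may need key remapping when loaded with the other.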

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 import torch
 import torch.nn as nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm


 class ConvRNNF0Predictor(nn.Module):

View File

@@ -23,7 +23,7 @@ import torch.nn.functional as F
 from torch.nn import Conv1d
 from torch.nn import ConvTranspose1d
 from torch.nn.utils import remove_weight_norm
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
 from torch.distributions.uniform import Uniform
 from cosyvoice.transformer.activation import Snake

View File

@@ -1,5 +1,5 @@
 # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
-#               2024 Alibaba Inc (authors: Xiang Lyu)
+#               2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
 # limitations under the License.
 import json
+import tensorrt as trt
 import torchaudio
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
@@ -45,3 +46,44 @@ def load_wav(wav, target_sr):
     assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
     speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
     return speech
+
+
+def convert_onnx_to_trt(trt_model, onnx_model, fp16):
+    _min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)]
+    _opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)]
+    _max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)]
+    input_names = ["x", "mask", "mu", "t", "spks", "cond"]
+
+    logging.info("Converting onnx to trt...")
+    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    logger = trt.Logger(trt.Logger.INFO)
+    builder = trt.Builder(logger)
+    network = builder.create_network(network_flags)
+    parser = trt.OnnxParser(network, logger)
+    config = builder.create_builder_config()
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33)  # 8GB
+    if fp16:
+        config.set_flag(trt.BuilderFlag.FP16)
+    profile = builder.create_optimization_profile()
+    # load onnx model
+    with open(onnx_model, "rb") as f:
+        if not parser.parse(f.read()):
+            for error in range(parser.num_errors):
+                print(parser.get_error(error))
+            raise ValueError('failed to parse {}'.format(onnx_model))
+    # set input shapes
+    for i in range(len(input_names)):
+        profile.set_shape(input_names[i], _min_shape[i], _opt_shape[i], _max_shape[i])
+    tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
+    # set input and output data type
+    for i in range(network.num_inputs):
+        input_tensor = network.get_input(i)
+        input_tensor.dtype = tensor_dtype
+    for i in range(network.num_outputs):
+        output_tensor = network.get_output(i)
+        output_tensor.dtype = tensor_dtype
+    config.add_optimization_profile(profile)
+    engine_bytes = builder.build_serialized_network(network, config)
+    # save trt engine
+    with open(trt_model, "wb") as f:
+        f.write(engine_bytes)
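
The new helper builds a TensorRT engine from the exported flow-decoder estimator ONNX with a dynamic-shape profile: batch is fixed at 2 and the mel dimension at 80 (matching the (2, 80, T) inputs bound in ConditionalCFM above), while the time axis may range from 4 to 6800 frames, with 193 as the optimum. A hedged usage sketch (paths are illustrative; in practice load_trt() invokes this automatically when the .mygpu.plan file is absent):

    # Sketch: building a GPU-specific fp16 engine for the flow decoder estimator.
    from cosyvoice.utils.file_utils import convert_onnx_to_trt

    model_dir = 'pretrained_models/CosyVoice2-0.5B'
    convert_onnx_to_trt('{}/flow.decoder.estimator.fp16.mygpu.plan'.format(model_dir),
                        '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                        fp16=True)

Because the serialized engine is specific to the local GPU and TensorRT version, the resulting plan file is machine-specific, which is the rationale for the v100-to-mygpu rename in the CLI.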

View File

@@ -24,7 +24,7 @@ import numpy as np
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../../..'.format(ROOT_DIR))
 sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 from cosyvoice.utils.file_utils import load_wav

 app = FastAPI()
@@ -79,5 +79,11 @@ if __name__ == '__main__':
                         default='iic/CosyVoice-300M',
                         help='local path or modelscope repo id')
     args = parser.parse_args()
-    cosyvoice = CosyVoice(args.model_dir)
+    try:
+        cosyvoice = CosyVoice(args.model_dir)
+    except Exception:
+        try:
+            cosyvoice = CosyVoice2(args.model_dir)
+        except Exception:
+            raise TypeError('no valid model_type!')
     uvicorn.run(app, host="0.0.0.0", port=args.port)
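
The same try/except fallback is repeated in the grpc server and the webui below: each entry point first attempts to load the directory as a CosyVoice model, falls back to CosyVoice2, and raises only when neither wrapper accepts it. A hypothetical helper (not part of the PR) that captures the shared pattern:

    # Hypothetical helper, not in the PR: pick whichever CLI wrapper can load
    # the given model directory, mirroring the fallback used by the servers.
    from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2

    def load_cosyvoice(model_dir):
        for cls in (CosyVoice, CosyVoice2):
            try:
                return cls(model_dir)
            except Exception:
                continue
        raise TypeError('no valid model_type!')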

View File

@@ -25,7 +25,7 @@ import numpy as np
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../../..'.format(ROOT_DIR))
 sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2

 logging.basicConfig(level=logging.DEBUG,
                     format='%(asctime)s %(levelname)s %(message)s')
@@ -33,7 +33,13 @@ logging.basicConfig(level=logging.DEBUG,
 class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
     def __init__(self, args):
-        self.cosyvoice = CosyVoice(args.model_dir)
+        try:
+            self.cosyvoice = CosyVoice(args.model_dir)
+        except Exception:
+            try:
+                self.cosyvoice = CosyVoice2(args.model_dir)
+            except Exception:
+                raise TypeError('no valid model_type!')
         logging.info('grpc service initialized')

     def Inference(self, request, context):

View File

@@ -184,7 +184,14 @@ if __name__ == '__main__':
                         default='pretrained_models/CosyVoice2-0.5B',
                         help='local path or modelscope repo id')
     args = parser.parse_args()
-    cosyvoice = CosyVoice2(args.model_dir) if 'CosyVoice2' in args.model_dir else CosyVoice(args.model_dir)
+    try:
+        cosyvoice = CosyVoice(args.model_dir)
+    except Exception:
+        try:
+            cosyvoice = CosyVoice2(args.model_dir)
+        except Exception:
+            raise TypeError('no valid model_type!')
     sft_spk = cosyvoice.list_available_spks()
     prompt_sr = 16000
     default_data = np.zeros(cosyvoice.sample_rate)