mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
@@ -54,11 +54,13 @@ class CosyVoice:
|
||||
'{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||
'{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||
if load_trt:
|
||||
self.estimator_count = configs.get('estimator_count', 1)
|
||||
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||
self.fp16)
|
||||
self.fp16, self.estimator_count)
|
||||
del configs
|
||||
|
||||
|
||||
def list_available_spks(self):
|
||||
spks = list(self.frontend.spk2info.keys())
|
||||
return spks
|
||||
@@ -178,11 +180,13 @@ class CosyVoice2(CosyVoice):
|
||||
if load_jit:
|
||||
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||
if load_trt:
|
||||
self.estimator_count = configs.get('estimator_count', 1)
|
||||
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||
self.fp16)
|
||||
self.fp16, self.estimator_count)
|
||||
del configs
|
||||
|
||||
|
||||
def inference_instruct(self, *args, **kwargs):
|
||||
raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
|
||||
|
||||
|
||||
@@ -22,7 +22,8 @@ from contextlib import nullcontext
|
||||
import uuid
|
||||
from cosyvoice.utils.common import fade_in_out
|
||||
from cosyvoice.utils.file_utils import convert_onnx_to_trt
|
||||
|
||||
from cosyvoice.flow.flow_matching import EstimatorWrapper
|
||||
import queue
|
||||
|
||||
class CosyVoiceModel:
|
||||
|
||||
@@ -66,6 +67,12 @@ class CosyVoiceModel:
|
||||
self.flow_cache_dict = {}
|
||||
self.hift_cache_dict = {}
|
||||
|
||||
self.stream_context_pool = queue.Queue()
|
||||
for _ in range(10):
|
||||
self.stream_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
|
||||
|
||||
self.is_cuda_available = torch.cuda.is_available()
|
||||
|
||||
def load(self, llm_model, flow_model, hift_model):
|
||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
||||
self.llm.to(self.device).eval()
|
||||
@@ -84,7 +91,7 @@ class CosyVoiceModel:
|
||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||
self.flow.encoder = flow_encoder
|
||||
|
||||
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
|
||||
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16, estimator_count=1):
|
||||
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
||||
if not os.path.exists(flow_decoder_estimator_model):
|
||||
convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
|
||||
@@ -96,7 +103,7 @@ class CosyVoiceModel:
|
||||
self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||
if self.flow.decoder.estimator_engine is None:
|
||||
raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
|
||||
self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
|
||||
self.flow.decoder.estimator = EstimatorWrapper(self.flow.decoder.estimator_engine, estimator_count=estimator_count)
|
||||
|
||||
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
||||
with self.llm_context:
|
||||
@@ -166,6 +173,10 @@ class CosyVoiceModel:
|
||||
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
||||
# this_uuid is used to track variables related to this inference thread
|
||||
|
||||
stream_context = self.stream_context_pool.get()
|
||||
with stream_context:
|
||||
|
||||
this_uuid = str(uuid.uuid1())
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||
@@ -222,6 +233,9 @@ class CosyVoiceModel:
|
||||
self.mel_overlap_dict.pop(this_uuid)
|
||||
self.hift_cache_dict.pop(this_uuid)
|
||||
self.flow_cache_dict.pop(this_uuid)
|
||||
|
||||
self.synchronize_stream()
|
||||
self.stream_context_pool.put(stream_context)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
|
||||
@@ -278,6 +292,10 @@ class CosyVoiceModel:
|
||||
self.hift_cache_dict.pop(this_uuid)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def synchronize_stream(self):
|
||||
if self.is_cuda_available:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
|
||||
class CosyVoice2Model(CosyVoiceModel):
|
||||
|
||||
@@ -314,11 +332,18 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
self.llm_end_dict = {}
|
||||
self.hift_cache_dict = {}
|
||||
|
||||
self.stream_context_pool = queue.Queue()
|
||||
for _ in range(10):
|
||||
self.stream_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
|
||||
|
||||
self.is_cuda_available = torch.cuda.is_available()
|
||||
|
||||
def load_jit(self, flow_encoder_model):
|
||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||
self.flow.encoder = flow_encoder
|
||||
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
|
||||
|
||||
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
@@ -358,6 +383,10 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
||||
# this_uuid is used to track variables related to this inference thread
|
||||
self.synchronize_stream()
|
||||
stream_context = self.stream_context_pool.get()
|
||||
with torch.cuda.stream(stream_context):
|
||||
|
||||
this_uuid = str(uuid.uuid1())
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||
@@ -408,6 +437,9 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict.pop(this_uuid)
|
||||
self.llm_end_dict.pop(this_uuid)
|
||||
|
||||
self.synchronize_stream()
|
||||
self.stream_context_pool.put(stream_context)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,26 @@ import threading
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from matcha.models.components.flow_matching import BASECFM
|
||||
import queue
|
||||
|
||||
class EstimatorWrapper:
|
||||
def __init__(self, estimator_engine, estimator_count=2,):
|
||||
self.estimators = queue.Queue()
|
||||
self.estimator_engine = estimator_engine
|
||||
for _ in range(estimator_count):
|
||||
estimator = estimator_engine.create_execution_context()
|
||||
if estimator is not None:
|
||||
self.estimators.put(estimator)
|
||||
|
||||
if self.estimators.empty():
|
||||
raise Exception("No available estimator")
|
||||
|
||||
def acquire_estimator(self):
|
||||
return self.estimators.get(), self.estimator_engine
|
||||
|
||||
def release_estimator(self, estimator):
|
||||
self.estimators.put(estimator)
|
||||
return
|
||||
|
||||
class ConditionalCFM(BASECFM):
|
||||
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
||||
@@ -124,6 +143,34 @@ class ConditionalCFM(BASECFM):
|
||||
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
||||
if isinstance(self.estimator, torch.nn.Module):
|
||||
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
||||
else:
|
||||
if isinstance(self.estimator, EstimatorWrapper):
|
||||
estimator, engine = self.estimator.acquire_estimator()
|
||||
|
||||
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('t', (2,))
|
||||
estimator.set_input_shape('spks', (2, 80))
|
||||
estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
|
||||
data_ptrs = [x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()]
|
||||
|
||||
for idx, data_ptr in enumerate(data_ptrs):
|
||||
estimator.set_tensor_address(engine.get_tensor_name(idx), data_ptr)
|
||||
|
||||
# run trt engine
|
||||
estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.estimator.release_estimator(estimator)
|
||||
return x
|
||||
else:
|
||||
with self.lock:
|
||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
|
||||
@@ -61,7 +61,7 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
|
||||
network = builder.create_network(network_flags)
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
config = builder.create_builder_config()
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33) # 8GB
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB
|
||||
if fp16:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
profile = builder.create_optimization_profile()
|
||||
|
||||
Reference in New Issue
Block a user