Merge pull request #1184 from hexisyztem/dev/Comet

Dev/comet
This commit is contained in:
Xiang Lyu
2025-04-16 15:02:46 +08:00
committed by GitHub
4 changed files with 230 additions and 147 deletions

View File

@@ -54,11 +54,13 @@ class CosyVoice:
'{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
if load_trt: if load_trt:
self.estimator_count = configs.get('estimator_count', 1)
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
self.fp16) self.fp16, self.estimator_count)
del configs del configs
def list_available_spks(self): def list_available_spks(self):
spks = list(self.frontend.spk2info.keys()) spks = list(self.frontend.spk2info.keys())
return spks return spks
@@ -178,11 +180,13 @@ class CosyVoice2(CosyVoice):
if load_jit: if load_jit:
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
if load_trt: if load_trt:
self.estimator_count = configs.get('estimator_count', 1)
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
self.fp16) self.fp16, self.estimator_count)
del configs del configs
def inference_instruct(self, *args, **kwargs): def inference_instruct(self, *args, **kwargs):
raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!') raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')

View File

@@ -22,7 +22,8 @@ from contextlib import nullcontext
import uuid import uuid
from cosyvoice.utils.common import fade_in_out from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt from cosyvoice.utils.file_utils import convert_onnx_to_trt
from cosyvoice.flow.flow_matching import EstimatorWrapper
import queue
class CosyVoiceModel: class CosyVoiceModel:
@@ -66,6 +67,12 @@ class CosyVoiceModel:
self.flow_cache_dict = {} self.flow_cache_dict = {}
self.hift_cache_dict = {} self.hift_cache_dict = {}
self.stream_context_pool = queue.Queue()
for _ in range(10):
self.stream_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
self.is_cuda_available = torch.cuda.is_available()
def load(self, llm_model, flow_model, hift_model): def load(self, llm_model, flow_model, hift_model):
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True) self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
self.llm.to(self.device).eval() self.llm.to(self.device).eval()
@@ -84,7 +91,7 @@ class CosyVoiceModel:
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder self.flow.encoder = flow_encoder
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16): def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16, estimator_count=1):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!' assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
if not os.path.exists(flow_decoder_estimator_model): if not os.path.exists(flow_decoder_estimator_model):
convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16) convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
@@ -96,7 +103,7 @@ class CosyVoiceModel:
self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
if self.flow.decoder.estimator_engine is None: if self.flow.decoder.estimator_engine is None:
raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model)) raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context() self.flow.decoder.estimator = EstimatorWrapper(self.flow.decoder.estimator_engine, estimator_count=estimator_count)
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
with self.llm_context: with self.llm_context:
@@ -166,6 +173,10 @@ class CosyVoiceModel:
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread # this_uuid is used to track variables related to this inference thread
stream_context = self.stream_context_pool.get()
with stream_context:
this_uuid = str(uuid.uuid1()) this_uuid = str(uuid.uuid1())
with self.lock: with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
@@ -222,6 +233,9 @@ class CosyVoiceModel:
self.mel_overlap_dict.pop(this_uuid) self.mel_overlap_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid)
self.flow_cache_dict.pop(this_uuid) self.flow_cache_dict.pop(this_uuid)
self.synchronize_stream()
self.stream_context_pool.put(stream_context)
torch.cuda.empty_cache() torch.cuda.empty_cache()
def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs): def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
@@ -278,6 +292,10 @@ class CosyVoiceModel:
self.hift_cache_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid)
torch.cuda.empty_cache() torch.cuda.empty_cache()
def synchronize_stream(self):
    """Wait for the current CUDA stream to finish its queued work.

    Does nothing when ``self.is_cuda_available`` is false, so callers can
    invoke it unconditionally on CPU-only hosts.
    """
    if not self.is_cuda_available:
        return
    active_stream = torch.cuda.current_stream()
    active_stream.synchronize()
class CosyVoice2Model(CosyVoiceModel): class CosyVoice2Model(CosyVoiceModel):
@@ -314,11 +332,18 @@ class CosyVoice2Model(CosyVoiceModel):
self.llm_end_dict = {} self.llm_end_dict = {}
self.hift_cache_dict = {} self.hift_cache_dict = {}
self.stream_context_pool = queue.Queue()
for _ in range(10):
self.stream_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
self.is_cuda_available = torch.cuda.is_available()
def load_jit(self, flow_encoder_model): def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder self.flow.encoder = flow_encoder
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0): def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
tts_mel, _ = self.flow.inference(token=token.to(self.device), tts_mel, _ = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device), prompt_token=prompt_token.to(self.device),
@@ -358,6 +383,10 @@ class CosyVoice2Model(CosyVoiceModel):
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread # this_uuid is used to track variables related to this inference thread
self.synchronize_stream()
stream_context = self.stream_context_pool.get()
with torch.cuda.stream(stream_context):
this_uuid = str(uuid.uuid1()) this_uuid = str(uuid.uuid1())
with self.lock: with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
@@ -408,6 +437,9 @@ class CosyVoice2Model(CosyVoiceModel):
with self.lock: with self.lock:
self.tts_speech_token_dict.pop(this_uuid) self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid)
self.synchronize_stream()
self.stream_context_pool.put(stream_context)
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@@ -15,7 +15,26 @@ import threading
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM from matcha.models.components.flow_matching import BASECFM
import queue
class EstimatorWrapper:
    """Blocking pool of execution contexts created from a single engine.

    Contexts are handed out with :meth:`acquire_estimator` and must be given
    back with :meth:`release_estimator`; the pool is a ``queue.Queue``, so an
    acquire on an empty pool waits until another caller releases a context.
    """

    def __init__(self, estimator_engine, estimator_count=2):
        """Create up to ``estimator_count`` execution contexts.

        Args:
            estimator_engine: Object exposing ``create_execution_context()``
                (the deserialized TensorRT engine in the caller shown in this
                change).
            estimator_count: Number of contexts to try to create.

        Raises:
            RuntimeError: If no context could be created, either because
                ``estimator_count`` is not positive or because every
                ``create_execution_context()`` call returned ``None``.
        """
        self.estimators = queue.Queue()
        self.estimator_engine = estimator_engine
        for _ in range(estimator_count):
            estimator = estimator_engine.create_execution_context()
            # A failed creation yields None; keep only usable contexts.
            if estimator is not None:
                self.estimators.put(estimator)
        if self.estimators.empty():
            # RuntimeError is still caught by callers using `except Exception`.
            raise RuntimeError("No available estimator")

    def acquire_estimator(self):
        """Take a context from the pool, waiting if none is free.

        Returns:
            A ``(context, engine)`` tuple; the engine is the one passed to
            the constructor.
        """
        return self.estimators.get(), self.estimator_engine

    def release_estimator(self, estimator):
        """Put a previously acquired context back into the pool."""
        self.estimators.put(estimator)
class ConditionalCFM(BASECFM): class ConditionalCFM(BASECFM):
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None): def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
@@ -124,6 +143,34 @@ class ConditionalCFM(BASECFM):
def forward_estimator(self, x, mask, mu, t, spks, cond): def forward_estimator(self, x, mask, mu, t, spks, cond):
if isinstance(self.estimator, torch.nn.Module): if isinstance(self.estimator, torch.nn.Module):
return self.estimator.forward(x, mask, mu, t, spks, cond) return self.estimator.forward(x, mask, mu, t, spks, cond)
else:
if isinstance(self.estimator, EstimatorWrapper):
estimator, engine = self.estimator.acquire_estimator()
estimator.set_input_shape('x', (2, 80, x.size(2)))
estimator.set_input_shape('mask', (2, 1, x.size(2)))
estimator.set_input_shape('mu', (2, 80, x.size(2)))
estimator.set_input_shape('t', (2,))
estimator.set_input_shape('spks', (2, 80))
estimator.set_input_shape('cond', (2, 80, x.size(2)))
data_ptrs = [x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()]
for idx, data_ptr in enumerate(data_ptrs):
estimator.set_tensor_address(engine.get_tensor_name(idx), data_ptr)
# run trt engine
estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream)
torch.cuda.current_stream().synchronize()
self.estimator.release_estimator(estimator)
return x
else: else:
with self.lock: with self.lock:
self.estimator.set_input_shape('x', (2, 80, x.size(2))) self.estimator.set_input_shape('x', (2, 80, x.size(2)))

View File

@@ -61,7 +61,7 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
network = builder.create_network(network_flags) network = builder.create_network(network_flags)
parser = trt.OnnxParser(network, logger) parser = trt.OnnxParser(network, logger)
config = builder.create_builder_config() config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33) # 8GB config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB
if fp16: if fp16:
config.set_flag(trt.BuilderFlag.FP16) config.set_flag(trt.BuilderFlag.FP16)
profile = builder.create_optimization_profile() profile = builder.create_optimization_profile()