mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
Enhance CosyVoice with CUDA stream management and estimator handling
- Introduced a queue-based system for managing CUDA streams to improve inference performance. - Updated inference methods to utilize CUDA streams for asynchronous processing. - Added an EstimatorWrapper class to manage TensorRT estimators, allowing for efficient execution context handling. - Modified model loading functions to support estimator count configuration. - Improved logging and performance tracking during inference operations.
This commit is contained in:
@@ -15,7 +15,26 @@ import threading
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from matcha.models.components.flow_matching import BASECFM
|
||||
import queue
|
||||
|
||||
class EstimatorWrapper:
|
||||
def __init__(self, estimator_engine, estimator_count=2,):
|
||||
self.estimators = queue.Queue()
|
||||
self.estimator_engine = estimator_engine
|
||||
for _ in range(estimator_count):
|
||||
estimator = estimator_engine.create_execution_context()
|
||||
if estimator is not None:
|
||||
self.estimators.put(estimator)
|
||||
|
||||
if self.estimators.empty():
|
||||
raise Exception("No available estimator")
|
||||
|
||||
def acquire_estimator(self):
|
||||
return self.estimators.get(), self.estimator_engine
|
||||
|
||||
def release_estimator(self, estimator):
|
||||
self.estimators.put(estimator)
|
||||
return
|
||||
|
||||
class ConditionalCFM(BASECFM):
|
||||
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
||||
@@ -125,22 +144,50 @@ class ConditionalCFM(BASECFM):
|
||||
if isinstance(self.estimator, torch.nn.Module):
|
||||
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
||||
else:
|
||||
with self.lock:
|
||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('t', (2,))
|
||||
self.estimator.set_input_shape('spks', (2, 80))
|
||||
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
if isinstance(self.estimator, EstimatorWrapper):
|
||||
estimator, engine = self.estimator.acquire_estimator()
|
||||
|
||||
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('t', (2,))
|
||||
estimator.set_input_shape('spks', (2, 80))
|
||||
estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
|
||||
data_ptrs = [x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()]
|
||||
|
||||
for idx, data_ptr in enumerate(data_ptrs):
|
||||
estimator.set_tensor_address(engine.get_tensor_name(idx), data_ptr)
|
||||
|
||||
# run trt engine
|
||||
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()])
|
||||
return x
|
||||
estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.estimator.release_estimator(estimator)
|
||||
return x
|
||||
else:
|
||||
with self.lock:
|
||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('t', (2,))
|
||||
self.estimator.set_input_shape('spks', (2, 80))
|
||||
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
# run trt engine
|
||||
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()])
|
||||
return x
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
||||
"""Computes diffusion loss
|
||||
|
||||
Reference in New Issue
Block a user