mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
add vllm inference
This commit is contained in:
@@ -16,6 +16,7 @@ import threading
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from matcha.models.components.flow_matching import BASECFM
|
||||
from cosyvoice.utils.common import set_all_random_seed
|
||||
|
||||
|
||||
class ConditionalCFM(BASECFM):
|
||||
@@ -32,7 +33,6 @@ class ConditionalCFM(BASECFM):
|
||||
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
||||
# Just change the architecture of the estimator here
|
||||
self.estimator = estimator
|
||||
self.lock = threading.Lock()
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
|
||||
@@ -127,26 +127,27 @@ class ConditionalCFM(BASECFM):
|
||||
if isinstance(self.estimator, torch.nn.Module):
|
||||
return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
|
||||
else:
|
||||
estimator, trt_engine = self.estimator.acquire_estimator()
|
||||
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('t', (2,))
|
||||
estimator.set_input_shape('spks', (2, 80))
|
||||
estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
data_ptrs = [x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()]
|
||||
for i, j in enumerate(data_ptrs):
|
||||
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
||||
# run trt engine
|
||||
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.estimator.release_estimator(estimator)
|
||||
[estimator, stream], trt_engine = self.estimator.acquire_estimator()
|
||||
with stream:
|
||||
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('t', (2,))
|
||||
estimator.set_input_shape('spks', (2, 80))
|
||||
estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
data_ptrs = [x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()]
|
||||
for i, j in enumerate(data_ptrs):
|
||||
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
||||
# run trt engine
|
||||
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.estimator.release_estimator(estimator, stream)
|
||||
return x
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
|
||||
@@ -194,6 +195,7 @@ class ConditionalCFM(BASECFM):
|
||||
class CausalConditionalCFM(ConditionalCFM):
|
||||
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
||||
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
|
||||
set_all_random_seed(0)
|
||||
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
Reference in New Issue
Block a user