commit 9b052a94c4
parent 6dd68b9d5e
Author: lyuxiang.lx
Date:   2025-05-30 07:51:49 +00:00

8 changed files with 125 additions and 236 deletions


@@ -14,7 +14,6 @@
 # limitations under the License.
 import os
 from typing import Generator
-import queue
 import torch
 import numpy as np
 import threading
@@ -33,14 +32,12 @@ class CosyVoiceModel:
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
-                 fp16: bool = False,
-                 trt_concurrent: int = 1):
+                 fp16: bool = False):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        self.trt_concurrent = trt_concurrent
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
@@ -85,7 +82,7 @@ class CosyVoiceModel:
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
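
The signature change above moves trt_concurrent out of model state and into the one method that actually uses it. A minimal usage sketch of the new call pattern, assuming an already-constructed model; the file paths and concurrency value here are illustrative placeholders, not taken from this commit:

# Hedged sketch: trt_concurrent is now an argument of load_trt(), not __init__.
# The paths below are placeholders for the serialized TRT plan and its ONNX source.
model = CosyVoiceModel(llm, flow, hift, fp16=True)
model.load_trt('flow.decoder.estimator.fp16.plan',
               'flow.decoder.estimator.fp32.onnx',
               trt_concurrent=2,   # number of concurrent TRT execution contexts
               fp16=True)
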
@@ -94,7 +91,7 @@ class CosyVoiceModel:
         with open(flow_decoder_estimator_model, 'rb') as f:
             estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
         assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
-        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent, device=self.device)
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

     def get_trt_kwargs(self):
         min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
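
TrtContextWrapper itself is not part of this diff, so the following is only a plausible sketch of what a concurrency-aware wrapper could look like: it pools trt_concurrent execution contexts behind a blocking queue so worker threads never share one context. The class name, methods, and pooling strategy here are assumptions, not the repository's actual implementation:

import queue
import tensorrt as trt  # assumed available; load_trt already requires CUDA

class TrtContextWrapperSketch:
    """Hypothetical stand-in for TrtContextWrapper: one execution
    context per concurrent caller, handed out through a blocking queue."""

    def __init__(self, engine, trt_concurrent=1, device=None):
        self.engine = engine
        self.device = device
        self.contexts = queue.Queue()
        for _ in range(trt_concurrent):
            # ICudaEngine.create_execution_context() is standard TensorRT API.
            self.contexts.put(engine.create_execution_context())

    def acquire(self):
        return self.contexts.get()   # blocks while all contexts are in use

    def release(self, ctx):
        self.contexts.put(ctx)
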
@@ -104,7 +101,7 @@ class CosyVoiceModel:
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
-        with self.llm_context, torch.cuda.amp.autocast(self.fp16):
+        with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
                 assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
                 for i in self.llm.inference_bistream(text=text,
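
The new autocast guard reads as: run fp16 autocast only when fp16 was requested and the LLM does not carry a vllm attribute. The apparent intent (my inference, not stated in the diff) is that a vLLM-backed LLM manages its own precision, so autocast must stay off for it. The guard condenses to a simple predicate:

import torch

def llm_autocast_enabled(fp16: bool, llm: torch.nn.Module) -> bool:
    # Equivalent to `fp16 is True and hasattr(llm, 'vllm') is False`,
    # written in the more idiomatic boolean form.
    return fp16 and not hasattr(llm, 'vllm')
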
@@ -246,14 +243,12 @@ class CosyVoice2Model(CosyVoiceModel):
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
-                 fp16: bool = False,
-                 trt_concurrent: int = 1):
+                 fp16: bool = False):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        self.trt_concurrent = trt_concurrent
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
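
Since CosyVoice2Model receives the identical constructor cleanup, existing callers of either class migrate the same way. A hedged before/after sketch; variable names and values are placeholders:

# Before this commit (constructor carried the setting):
#   model = CosyVoice2Model(llm, flow, hift, fp16=True, trt_concurrent=2)
#   model.load_trt(plan_path, onnx_path, fp16=True)
# After this commit (load_trt carries it):
model = CosyVoice2Model(llm, flow, hift, fp16=True)
model.load_trt(plan_path, onnx_path, trt_concurrent=2, fp16=True)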