This commit is contained in:
root
2025-10-08 18:13:09 +08:00
parent 7cbd490253
commit aceede59ba
5 changed files with 20 additions and 29 deletions

View File

@@ -103,6 +103,7 @@ class TritonPythonModel:
self.http_client = httpx.AsyncClient()
self.api_base = "http://localhost:8000/v1/chat/completions"
self.speaker_cache = {}
def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
"""Converts a tensor or list of speech token IDs to a string representation."""
@@ -240,10 +241,12 @@ class TritonPythonModel:
"""Forward pass through the vocoder component.
Args:
prompt_speech_tokens: Prompt speech tokens tensor
prompt_speech_feat: Prompt speech feat tensor
prompt_spk_embedding: Prompt spk embedding tensor
index: Index of the request
target_speech_tokens: Target speech tokens tensor
request_id: Request ID
reference_wav: Reference waveform tensor
reference_wav_len: Reference waveform length tensor
finalize: Whether to finalize the request
Returns:
Generated waveform tensor
@@ -292,26 +295,17 @@ class TritonPythonModel:
async def _process_request(self, request):
request_id = request.request_id()
# Extract input tensors
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
# Process reference audio through audio tokenizer
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
wav_tensor = wav.as_numpy()
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
speech_feat = self._extract_speech_feat(prompt_speech_resample)
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
if reference_text not in self.speaker_cache:
self.speaker_cache[reference_text] = self.forward_audio_tokenizer(wav, wav_len).unsqueeze(0)
prompt_speech_tokens = self.speaker_cache[reference_text]
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode('utf-8')

View File

@@ -57,10 +57,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
if dtype == torch.float16:
config.set_flag(trt.BuilderFlag.FP16)
elif dtype == torch.bfloat16:
config.set_flag(trt.BuilderFlag.BF16)
elif dtype == torch.float32:
config.set_flag(trt.BuilderFlag.FP32)
profile = builder.create_optimization_profile()
# load onnx model
with open(onnx_model, "rb") as f:
@@ -199,7 +196,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
trt_kwargs = self.get_spk_trt_kwargs()
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float32)
import tensorrt as trt
with open(spk_model, 'rb') as f:
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())