From aceede59ba5fe97b65a4cb7f36b76f19de29b4f9 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 8 Oct 2025 18:13:09 +0800
Subject: [PATCH] fix bug

---
 runtime/triton_trtllm/client_grpc.py          |  2 +-
 runtime/triton_trtllm/docker-compose.dit.yml  |  2 +-
 .../model_repo/cosyvoice2_dit/1/model.py      | 32 ++++++++-----------
 .../token2wav_dit/1/token2wav_dit.py          |  7 ++--
 .../run_stepaudio2_dit_token2wav.sh           |  6 ++--
 5 files changed, 20 insertions(+), 29 deletions(-)

diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py
index 718fe86..840390d 100644
--- a/runtime/triton_trtllm/client_grpc.py
+++ b/runtime/triton_trtllm/client_grpc.py
@@ -424,7 +424,7 @@ def run_sync_streaming_inference(
     audios = []
     while True:
         try:
-            result = user_data._completed_requests.get(timeout=20)
+            result = user_data._completed_requests.get(timeout=200)
             if isinstance(result, InferenceServerException):
                 print(f"Received InferenceServerException: {result}")
                 return None, None, None, None
diff --git a/runtime/triton_trtllm/docker-compose.dit.yml b/runtime/triton_trtllm/docker-compose.dit.yml
index 1f97f7c..35312a1 100644
--- a/runtime/triton_trtllm/docker-compose.dit.yml
+++ b/runtime/triton_trtllm/docker-compose.dit.yml
@@ -17,4 +17,4 @@ services:
             device_ids: ['0']
             capabilities: [gpu]
     command: >
-      /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt && git clone https://github.com/yuekaizhang/CosyVoice.git -b streaming && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run.sh 0 3"
\ No newline at end of file
+      /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt && git clone https://github.com/yuekaizhang/CosyVoice.git -b streaming && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run_stepaudio2_dit_token2wav.sh 0 3"
\ No newline at end of file
diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py
index 827925c..523a5b8 100644
--- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py
+++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py
@@ -103,6 +103,7 @@ class TritonPythonModel:
 
         self.http_client = httpx.AsyncClient()
         self.api_base = "http://localhost:8000/v1/chat/completions"
+        self.speaker_cache = {}
 
     def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
         """Converts a tensor or list of speech token IDs to a string representation."""
@@ -240,10 +241,12 @@ class TritonPythonModel:
         """Forward pass through the vocoder component.
 
         Args:
-            prompt_speech_tokens: Prompt speech tokens tensor
-            prompt_speech_feat: Prompt speech feat tensor
-            prompt_spk_embedding: Prompt spk embedding tensor
+            index: Index of the request
             target_speech_tokens: Target speech tokens tensor
+            request_id: Request ID
+            reference_wav: Reference waveform tensor
+            reference_wav_len: Reference waveform length tensor
+            finalize: Whether to finalize the request
 
         Returns:
             Generated waveform tensor
@@ -292,26 +295,17 @@ class TritonPythonModel:
 
     async def _process_request(self, request):
         request_id = request.request_id()
 
-        # Extract input tensors
-        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
-
-        # Process reference audio through audio tokenizer
-
-        wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
-        prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
-        prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
-
-        wav_tensor = wav.as_numpy()
-        wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
-        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
-        speech_feat = self._extract_speech_feat(prompt_speech_resample)
-        token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
-        prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
-        prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
 
         reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
         reference_text = reference_text[0][0].decode('utf-8')
+        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
+        wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+
+        if reference_text not in self.speaker_cache:
+            self.speaker_cache[reference_text] = self.forward_audio_tokenizer(wav, wav_len).unsqueeze(0)
+        prompt_speech_tokens = self.speaker_cache[reference_text]
+
         target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
         target_text = target_text[0][0].decode('utf-8')
 
diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
index bda4cb1..3d50325 100644
--- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
+++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
@@ -57,10 +57,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
     # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
     if dtype == torch.float16:
         config.set_flag(trt.BuilderFlag.FP16)
-    elif dtype == torch.bfloat16:
-        config.set_flag(trt.BuilderFlag.BF16)
-    elif dtype == torch.float32:
-        config.set_flag(trt.BuilderFlag.FP32)
+
     profile = builder.create_optimization_profile()
     # load onnx model
     with open(onnx_model, "rb") as f:
@@ -199,7 +196,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
     def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
         if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
             trt_kwargs = self.get_spk_trt_kwargs()
-            convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
+            convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float32)
         import tensorrt as trt
         with open(spk_model, 'rb') as f:
             spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
diff --git a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh
index c401793..5881b44 100644
--- a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh
+++ b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh
@@ -42,7 +42,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   echo "Step-Audio2-mini"
   huggingface-cli download --local-dir $step_audio_model_dir stepfun-ai/Step-Audio-2-mini
 
-  cd $stepaudio2_path/token2wav
+  cd $step_audio_model_dir/token2wav
   wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O flow.decoder.estimator.fp32.dynamic_batch.onnx
   wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx -O flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx
   cd -
@@ -100,8 +100,8 @@ fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   echo "Starting Token2wav Triton server and Cosyvoice2 llm using trtllm-serve"
-  tritonserver --model-repository $model_repo --http-port 18000 &
   mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4 &
+  tritonserver --model-repository $model_repo --http-port 18000 &
   wait
   # Test using curl
   # curl http://localhost:8000/v1/chat/completions \
@@ -168,7 +168,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
 
   # Note: Using pre-computed cosyvoice2 tokens
   python3 streaming_inference.py --enable-trt --strategy equal # equal, exponential
 
   # Offline Token2wav inference
-  # python3 token2wav_dit.py --enable-trt
+  python3 token2wav_dit.py --enable-trt
 fi
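
Note on the `speaker_cache` added in `cosyvoice2_dit/1/model.py`: prompt speech tokens are now keyed on the reference transcript, so `forward_audio_tokenizer` runs only for the first request carrying a given `reference_text`; later requests reuse the cached tokens (different reference audios sharing the same transcript would hit the same cache entry). Below is a minimal standalone sketch of that caching pattern; the `SpeakerCacheSketch` class, its `_tokenize` stub, and the token shapes are illustrative stand-ins, not the actual Triton model code.

```python
# Sketch (assumption-labeled): per-reference-text caching of prompt speech tokens.
# The _tokenize stub stands in for the expensive audio-tokenizer call that the
# Triton model performs via self.forward_audio_tokenizer(wav, wav_len).
import torch


class SpeakerCacheSketch:
    def __init__(self):
        # reference_text -> prompt speech tokens, shape (1, num_tokens)
        self.speaker_cache = {}

    def _tokenize(self, wav: torch.Tensor, wav_len: int) -> torch.Tensor:
        # Stand-in for the expensive tokenizer; values and shape are illustrative.
        return torch.randint(0, 6561, (1, max(wav_len // 640, 1)))

    def prompt_tokens(self, reference_text: str, wav: torch.Tensor, wav_len: int) -> torch.Tensor:
        # Tokenize the reference audio only the first time this reference_text
        # is seen; afterwards return the cached tokens.
        if reference_text not in self.speaker_cache:
            self.speaker_cache[reference_text] = self._tokenize(wav, wav_len)
        return self.speaker_cache[reference_text]


if __name__ == "__main__":
    cache = SpeakerCacheSketch()
    wav = torch.zeros(1, 16000)
    first = cache.prompt_tokens("hello world", wav, 16000)
    second = cache.prompt_tokens("hello world", wav, 16000)  # cache hit, no re-tokenization
    assert first is second
```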