Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-04 17:39:25 +08:00)

Commit: fix bug
@@ -424,7 +424,7 @@ def run_sync_streaming_inference(
     audios = []
     while True:
         try:
-            result = user_data._completed_requests.get(timeout=20)
+            result = user_data._completed_requests.get(timeout=200)
             if isinstance(result, InferenceServerException):
                 print(f"Received InferenceServerException: {result}")
                 return None, None, None, None
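Note: the only functional change in this hunk is the longer per-chunk wait (20 s to 200 s) while draining the streaming result queue. Below is a minimal standalone sketch of that consumer pattern, assuming the standard Triton streaming-client setup where a callback pushes each result into user_data._completed_requests; the queue.Empty handling and the name drain_results are illustrative assumptions, since the matching except clause sits outside the hunk.

# Minimal sketch of the consumer loop this hunk tunes; queue.Empty handling is an assumption.
import queue

from tritonclient.utils import InferenceServerException


def drain_results(user_data, timeout_s: float = 200):
    audios = []
    while True:
        try:
            # Block for up to timeout_s seconds waiting for the next streamed chunk.
            result = user_data._completed_requests.get(timeout=timeout_s)
        except queue.Empty:
            # No result arrived within the timeout; stop draining.
            # Real clients usually also break on an explicit end-of-stream marker.
            break
        if isinstance(result, InferenceServerException):
            print(f"Received InferenceServerException: {result}")
            return None
        audios.append(result)
    return audios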
@@ -17,4 +17,4 @@ services:
             device_ids: ['0']
             capabilities: [gpu]
     command: >
-      /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt && git clone https://github.com/yuekaizhang/CosyVoice.git -b streaming && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run.sh 0 3"
+      /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt && git clone https://github.com/yuekaizhang/CosyVoice.git -b streaming && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run_stepaudio2_dit_token2wav.sh 0 3"
@@ -103,6 +103,7 @@ class TritonPythonModel:

         self.http_client = httpx.AsyncClient()
         self.api_base = "http://localhost:8000/v1/chat/completions"
+        self.speaker_cache = {}

     def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
         """Converts a tensor or list of speech token IDs to a string representation."""
@@ -240,10 +241,12 @@ class TritonPythonModel:
         """Forward pass through the vocoder component.

         Args:
-            prompt_speech_tokens: Prompt speech tokens tensor
-            prompt_speech_feat: Prompt speech feat tensor
-            prompt_spk_embedding: Prompt spk embedding tensor
+            index: Index of the request
             target_speech_tokens: Target speech tokens tensor
+            request_id: Request ID
+            reference_wav: Reference waveform tensor
+            reference_wav_len: Reference waveform length tensor
+            finalize: Whether to finalize the request

         Returns:
             Generated waveform tensor
@@ -292,26 +295,17 @@ class TritonPythonModel:

     async def _process_request(self, request):
         request_id = request.request_id()
-        # Extract input tensors
-        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
-
-        # Process reference audio through audio tokenizer
-
-        wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
-        prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
-        prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
-
-        wav_tensor = wav.as_numpy()
-        wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
-        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
-        speech_feat = self._extract_speech_feat(prompt_speech_resample)
-        token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
-        prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
-        prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
-
         reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
         reference_text = reference_text[0][0].decode('utf-8')

+        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
+        wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+
+        if reference_text not in self.speaker_cache:
+            self.speaker_cache[reference_text] = self.forward_audio_tokenizer(wav, wav_len).unsqueeze(0)
+        prompt_speech_tokens = self.speaker_cache[reference_text]
+
         target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
         target_text = target_text[0][0].decode('utf-8')

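Note: this rewrite replaces per-request prompt processing (resampling, feature extraction, token truncation) with a cache keyed by the reference text, so repeated requests with the same prompt reuse the speech tokens extracted on first use. Below is a minimal standalone sketch of that memoization pattern; SpeakerCacheDemo, tokenize and get_prompt_tokens are illustrative names, and the fake tokenizer stands in for forward_audio_tokenizer.

# Minimal sketch of the speaker cache introduced above; only the caching logic mirrors the diff.
import torch


class SpeakerCacheDemo:
    def __init__(self):
        # Maps reference_text -> prompt speech tokens, filled lazily on first use.
        self.speaker_cache = {}

    def tokenize(self, wav: torch.Tensor) -> torch.Tensor:
        # Placeholder for the real audio tokenizer; returns fake token IDs.
        return torch.arange(16)

    def get_prompt_tokens(self, reference_text: str, wav: torch.Tensor) -> torch.Tensor:
        if reference_text not in self.speaker_cache:
            # Tokenize once per unique reference text and keep a batch dimension, as in the diff.
            self.speaker_cache[reference_text] = self.tokenize(wav).unsqueeze(0)
        return self.speaker_cache[reference_text]

Because the cache is keyed only on the reference text and is never evicted, it assumes one reference waveform per text and grows with the number of distinct prompts.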
@@ -57,10 +57,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
     # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
     if dtype == torch.float16:
         config.set_flag(trt.BuilderFlag.FP16)
-    elif dtype == torch.bfloat16:
-        config.set_flag(trt.BuilderFlag.BF16)
-    elif dtype == torch.float32:
-        config.set_flag(trt.BuilderFlag.FP32)
     profile = builder.create_optimization_profile()
     # load onnx model
     with open(onnx_model, "rb") as f:
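Note: removing these branches matches how TensorRT precision is selected. FP16 and BF16 are opt-in builder flags, while FP32 is the default precision and trt.BuilderFlag has no FP32 member, so the removed FP32 branch would raise an AttributeError if reached, which the load_spk_trt change below, now passing torch.float32, would presumably have triggered. A minimal sketch of the flag selection, assuming a builder config already created with builder.create_builder_config(); set_precision_flags is an illustrative name, not the repository's.

# Minimal sketch of TensorRT precision-flag selection under the assumptions above.
import tensorrt as trt
import torch


def set_precision_flags(config, dtype: torch.dtype) -> None:
    # FP32 is TensorRT's default precision: there is no FP32 builder flag, so only the
    # reduced-precision modes are opted into explicitly.
    if dtype == torch.float16:
        config.set_flag(trt.BuilderFlag.FP16)
    elif dtype == torch.bfloat16:
        config.set_flag(trt.BuilderFlag.BF16)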
@@ -199,7 +196,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
     def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
         if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
             trt_kwargs = self.get_spk_trt_kwargs()
-            convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
+            convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float32)
         import tensorrt as trt
         with open(spk_model, 'rb') as f:
             spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
@@ -42,7 +42,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then

   echo "Step-Audio2-mini"
   huggingface-cli download --local-dir $step_audio_model_dir stepfun-ai/Step-Audio-2-mini
-  cd $stepaudio2_path/token2wav
+  cd $step_audio_model_dir/token2wav
   wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O flow.decoder.estimator.fp32.dynamic_batch.onnx
   wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx -O flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx
   cd -
@@ -100,8 +100,8 @@ fi

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   echo "Starting Token2wav Triton server and Cosyvoice2 llm using trtllm-serve"
-  tritonserver --model-repository $model_repo --http-port 18000 &
   mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4 &
+  tritonserver --model-repository $model_repo --http-port 18000 &
   wait
   # Test using curl
   # curl http://localhost:8000/v1/chat/completions \
@@ -168,7 +168,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   # Note: Using pre-computed cosyvoice2 tokens
   python3 streaming_inference.py --enable-trt --strategy equal # equal, exponential
   # Offline Token2wav inference
-  # python3 token2wav_dit.py --enable-trt
+  python3 token2wav_dit.py --enable-trt
 fi
