clean code

This commit is contained in:
root
2025-10-08 15:21:52 +08:00
parent 988d395162
commit f186ec3338
5 changed files with 266 additions and 800 deletions

View File

@@ -28,9 +28,10 @@ import json
import math
import os
import re
import threading
import time
from typing import Dict, List, Tuple, Optional, Union
import asyncio
import httpx
import numpy as np
import torch
@@ -42,11 +43,30 @@ import torchaudio
from matcha.utils.audio import mel_spectrogram
from datetime import datetime
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
def parse_speech_token_string(response_text: str) -> List[int]:
"""
Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
"""
speech_tokens = response_text.strip().split('><')
if len(speech_tokens) > 1:
# Add back the missing '<' and '>' for proper parsing
speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
speech_ids = []
for token_str in speech_tokens:
match = re.match(r'<\|s_(\d+)\|>', token_str)
if match:
speech_ids.append(int(match.group(1)))
return speech_ids
class TritonPythonModel:
"""Triton Python model for Spark TTS.
@@ -67,6 +87,7 @@ class TritonPythonModel:
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.logger.log_info(f"model_params:{model_params}")
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
# self.dynamic_chunk_strategy = "equal"
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
# Initialize tokenizer
@@ -87,92 +108,86 @@ class TritonPythonModel:
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
self.default_spk_info = spk_info["001"]
self.http_client = httpx.AsyncClient()
def forward_llm(self, input_ids):
def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
"""Converts a tensor or list of speech token IDs to a string representation."""
if isinstance(speech_tokens, torch.Tensor):
# Ensure tensor is on CPU and flattened
speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
speech_id_str = ""
for token_id in speech_tokens:
# Convert token ID back to the speech number N
token_num = token_id - ORIGINAL_VOCAB_SIZE
speech_id_str += f"<|s_{token_num}|>"
return speech_id_str
async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
"""
Prepares the response from the language model based on the provided
inputs. Creates a `pb_utils.InferenceRequest` object with passed
`llm_request_inputs` to send to a decoupled TensorRTLLM model.
For each response from the language model:
- Checks for errors and raise an exception if any are found.
- Extracts the "output_ids" tensor from the response.
- Determines the finish reason based on the presence of the
end-of-sequence token or reaching the maximum length.
- Appends the generated token IDs to `output_ids`.
- If the finish reason is determined, decodes the output IDs to text
and prepares the final response.
The final response includes the generated text, finish reason,
completion tokens, prompt tokens, and total tokens.
Parameters
----------
- llm_request_inputs (dict): A dictionary containing the inputs for the language model.
Returns
-------
- pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
"""
# convert input_ids to numpy, with shape [1, sequence_length]
input_ids = input_ids.cpu().numpy()
max_tokens = 750
input_dict = {
"request_output_len": np.array([[max_tokens]], dtype=np.int32),
"end_id": np.array([[self.eos_token_id]], dtype=np.int32),
"pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
"streaming": np.array([[self.decoupled]], dtype=np.bool_),
"runtime_top_p": np.array([[0.95]], dtype=np.float32),
"runtime_top_k": np.array([[50]], dtype=np.int32),
"temperature": np.array([[0.8]], dtype=np.float32),
"repetition_penalty": np.array([[1.1]], dtype=np.float32),
"random_seed": np.array([[42]], dtype=np.uint64),
"input_ids": input_ids,
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
}
full_text = f"{reference_text}{target_text}"
prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
# Convert inputs to Triton tensors
input_tensor_list = [
pb_utils.Tensor(k, v) for k, v in input_dict.items()
chat = [
{"role": "user", "content": full_text},
{"role": "assistant", "content": prompt_speech_tokens_str}
]
# Create and execute inference request
llm_request = pb_utils.InferenceRequest(
model_name="tensorrt_llm",
requested_output_names=["output_ids", "sequence_length"],
inputs=input_tensor_list,
)
payload = {
"model": "trt_engines_bfloat16",
"messages": chat,
"max_tokens": 750,
"temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"repetition_penalty": 1.1,
"stop": ["<|eos1|>", "<|eos|>"],
"stream": True,
}
llm_responses = llm_request.exec(decoupled=self.decoupled)
if self.decoupled:
for llm_response in llm_responses:
if llm_response.has_error():
raise pb_utils.TritonModelException(llm_response.error().message())
api_base = "http://localhost:8000/v1/chat/completions"
# Extract and process output
output_ids = pb_utils.get_output_tensor_by_name(
llm_response, "output_ids").as_numpy()
seq_lens = pb_utils.get_output_tensor_by_name(
llm_response, "sequence_length").as_numpy()
buffer = ""
async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response:
print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}")
print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}")
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
line_data = line[len("data: "):].strip()
if line_data == "[DONE]":
break
try:
json_data = json.loads(line_data)
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
if content:
buffer += content
print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}")
while True:
match = re.search(r"<\|s_(\d+)\|>", buffer)
if not match:
break
# Get actual output IDs up to the sequence length
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
token_num = int(match.group(1))
final_id = token_num + ORIGINAL_VOCAB_SIZE
yield final_id
buffer = buffer[match.end():]
except json.JSONDecodeError:
self.logger.log_info(f"Skipping non-JSON line: {line_data}")
continue
yield actual_output_ids
else:
llm_response = llm_responses
if llm_response.has_error():
raise pb_utils.TritonModelException(llm_response.error().message())
# Process any remaining complete tokens in the buffer after the stream ends
while True:
match = re.search(r"<\|s_(\d+)\|>", buffer)
if not match:
break
token_num = int(match.group(1))
final_id = token_num + ORIGINAL_VOCAB_SIZE
yield final_id
buffer = buffer[match.end():]
# Extract and process output
output_ids = pb_utils.get_output_tensor_by_name(
llm_response, "output_ids").as_numpy()
seq_lens = pb_utils.get_output_tensor_by_name(
llm_response, "sequence_length").as_numpy()
# Get actual output IDs up to the sequence length
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
yield actual_output_ids
def forward_audio_tokenizer(self, wav, wav_len):
"""Forward pass through the audio tokenizer component.
@@ -225,7 +240,7 @@ class TritonPythonModel:
return prompt_spk_embedding
def forward_token2wav(
async def forward_token2wav(
self,
index: int,
target_speech_tokens: torch.Tensor,
@@ -247,17 +262,19 @@ class TritonPythonModel:
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav_dit',
requested_output_names=['waveform'],
requested_output_names=[
"waveform",
],
inputs=inputs_tensor,
request_id=request_id,
parameters={"priority": index+1},
)
inference_response = inference_request.exec()
inference_response = await inference_request.async_exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
@@ -267,14 +284,6 @@ class TritonPythonModel:
return waveform
def parse_input(self, text, prompt_text, prompt_speech_tokens):
total_text = f"{prompt_text}{text}"
prompt = self.prompt_template.format(input_text=total_text)
input_ids = self.tokenizer.encode(prompt)
input_ids = torch.tensor([input_ids], dtype=torch.int32)
input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
return input_ids
def _extract_speech_feat(self, speech):
speech_feat = mel_spectrogram(
speech,
@@ -292,106 +301,75 @@ class TritonPythonModel:
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat
def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
for generated_ids in generated_ids_iter:
generated_ids = generated_ids.tolist()
if len(generated_ids) == 0:
break
semantic_token_ids_arr.extend(generated_ids)
llm_is_done_flag[0] = True
async def _process_request(self, request):
request_id = request.request_id()
# Extract input tensors
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
def execute(self, requests):
"""Execute inference on the batched requests.
# Process reference audio through audio tokenizer
if wav is not None:
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
Args:
requests: List of inference requests
wav_tensor = wav.as_numpy()
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}")
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
speech_feat = self._extract_speech_feat(prompt_speech_resample)
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
Returns:
List of inference responses containing generated audio
"""
responses = []
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
# prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
for request in requests:
request_id = request.request_id()
# Extract input tensors
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
# reference_text = self.default_spk_info["prompt_text"]
# prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
# prompt_speech_feat = None
# prompt_spk_embedding = None
# Process reference audio through audio tokenizer
if wav is not None:
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
else:
# using pre-cached reference text
assert False, "using pre-cached reference text is not supported"
reference_text = self.default_spk_info["prompt_text"]
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
prompt_speech_feat = None
prompt_spk_embedding = None
wav_tensor = wav.as_numpy()
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
speech_feat = self._extract_speech_feat(prompt_speech_resample)
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode('utf-8')
print(f"target_text: {target_text}, time: {datetime.now()}")
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
# prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
if self.decoupled:
response_sender = request.get_response_sender()
# reference_text = self.default_spk_info["prompt_text"]
# prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
# prompt_speech_feat = None
# prompt_spk_embedding = None
else:
assert False, "wav is None"
# using pre-cached reference text
reference_text = self.default_spk_info["prompt_text"]
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
prompt_speech_feat = None
prompt_spk_embedding = None
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode('utf-8')
# Prepare prompt for LLM
input_ids = self.parse_input(
text=target_text,
prompt_text=reference_text,
semantic_token_ids_arr = []
token_offset, chunk_index = 0, 0
start_time = time.time()
this_token_hop_len = self.token_hop_len
print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}")
async for generated_ids in self.forward_llm_async(
target_text=target_text,
reference_text=reference_text,
prompt_speech_tokens=prompt_speech_tokens,
)
# Generate semantic tokens with LLM
generated_ids_iter = self.forward_llm(input_ids)
if self.decoupled:
response_sender = request.get_response_sender()
semantic_token_ids_arr = []
llm_is_done_flag = [False]
llm_thread = threading.Thread(
target=self._llm_gen_thread,
args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
)
llm_thread.start()
token_offset, chunk_index = 0, 0
start_time = time.time()
this_token_hop_len = self.token_hop_len
):
if not generated_ids:
break
semantic_token_ids_arr.append(generated_ids)
print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}")
while True:
pending_num = len(semantic_token_ids_arr) - token_offset
if llm_is_done_flag[0]:
break
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(
print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}")
sub_tts_speech = await self.forward_token2wav(
chunk_index,
this_tts_speech_token, request_id, wav, wav_len, False
)
print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}")
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
@@ -401,6 +379,8 @@ class TritonPythonModel:
if self.dynamic_chunk_strategy == "exponential":
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
elif self.dynamic_chunk_strategy == "equal":
this_token_hop_len = self.token_hop_len
elif self.dynamic_chunk_strategy == "time_based":
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
cost_time = time.time() - start_time
@@ -420,19 +400,36 @@ class TritonPythonModel:
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
chunk_index += 1
else:
time.sleep(0.02)
break
this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info("send tritonserver_response_complete_final to end")
else:
raise NotImplementedError("Decoupled mode is not supported")
llm_thread.join()
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info("send tritonserver_response_complete_final to end")
else:
raise NotImplementedError("Decoupled mode is not supported")
async def execute(self, requests):
"""Execute inference on the batched requests.
if not self.decoupled:
return responses
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated audio
"""
tasks = [
asyncio.create_task(self._process_request(request))
for request in requests
]
await asyncio.gather(*tasks)
return None
def finalize(self):
self.logger.log_info("Finalizing CosyVoice DIT model")
if hasattr(self, "http_client"):
asyncio.run(self.http_client.aclose())

View File

@@ -1,435 +0,0 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import math
import os
import re
import time
from typing import Dict, List, Tuple, Optional, Union
import asyncio
import httpx
import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
import torchaudio
from matcha.utils.audio import mel_spectrogram
from datetime import datetime
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
def parse_speech_token_string(response_text: str) -> List[int]:
"""
Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
"""
speech_tokens = response_text.strip().split('><')
if len(speech_tokens) > 1:
# Add back the missing '<' and '>' for proper parsing
speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
speech_ids = []
for token_str in speech_tokens:
match = re.match(r'<\|s_(\d+)\|>', token_str)
if match:
speech_ids.append(int(match.group(1)))
return speech_ids
class TritonPythonModel:
"""Triton Python model for Spark TTS.
This model orchestrates the end-to-end TTS pipeline by coordinating
between audio tokenizer, LLM, and vocoder components.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
self.logger = pb_utils.Logger
# Parse model parameters
self.model_config = json.loads(args['model_config'])
parameters = self.model_config['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.logger.log_info(f"model_params:{model_params}")
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
# self.dynamic_chunk_strategy = "equal"
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
# Initialize tokenizer
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
self.prompt_template = "<|sos|>{input_text}<|task_id|>"
self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
self.device = torch.device("cuda")
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
self.token_frame_rate = 25
self.flow_pre_lookahead_len = 3
self.token_hop_len = 15
spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
if not os.path.exists(spk_info_path):
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
self.default_spk_info = spk_info["001"]
self.http_client = httpx.AsyncClient()
def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
"""Converts a tensor or list of speech token IDs to a string representation."""
if isinstance(speech_tokens, torch.Tensor):
# Ensure tensor is on CPU and flattened
speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
speech_id_str = ""
for token_id in speech_tokens:
# Convert token ID back to the speech number N
token_num = token_id - ORIGINAL_VOCAB_SIZE
speech_id_str += f"<|s_{token_num}|>"
return speech_id_str
async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
"""
Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
"""
full_text = f"{reference_text}{target_text}"
prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
chat = [
{"role": "user", "content": full_text},
{"role": "assistant", "content": prompt_speech_tokens_str}
]
payload = {
"model": "trt_engines_bfloat16",
"messages": chat,
"max_tokens": 750,
"temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"repetition_penalty": 1.1,
"stop": ["<|eos1|>", "<|eos|>"],
"stream": True,
}
api_base = "http://localhost:8000/v1/chat/completions"
buffer = ""
async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response:
print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}")
print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}")
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
line_data = line[len("data: "):].strip()
if line_data == "[DONE]":
break
try:
json_data = json.loads(line_data)
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
if content:
buffer += content
print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}")
while True:
match = re.search(r"<\|s_(\d+)\|>", buffer)
if not match:
break
token_num = int(match.group(1))
final_id = token_num + ORIGINAL_VOCAB_SIZE
yield final_id
buffer = buffer[match.end():]
except json.JSONDecodeError:
self.logger.log_info(f"Skipping non-JSON line: {line_data}")
continue
# Process any remaining complete tokens in the buffer after the stream ends
while True:
match = re.search(r"<\|s_(\d+)\|>", buffer)
if not match:
break
token_num = int(match.group(1))
final_id = token_num + ORIGINAL_VOCAB_SIZE
yield final_id
buffer = buffer[match.end():]
def forward_audio_tokenizer(self, wav, wav_len):
"""Forward pass through the audio tokenizer component.
Args:
wav: Input waveform tensor
wav_len: Waveform length tensor
Returns:
Tuple of global and semantic tokens
"""
inference_request = pb_utils.InferenceRequest(
model_name='audio_tokenizer',
requested_output_names=['prompt_speech_tokens'],
inputs=[wav, wav_len]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
return prompt_speech_tokens
def forward_speaker_embedding(self, wav):
"""Forward pass through the speaker embedding component.
Args:
wav: Input waveform tensor
Returns:
Prompt speaker embedding tensor
"""
inference_request = pb_utils.InferenceRequest(
model_name='speaker_embedding',
requested_output_names=['prompt_spk_embedding'],
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
return prompt_spk_embedding
async def forward_token2wav(
self,
index: int,
target_speech_tokens: torch.Tensor,
request_id: str,
reference_wav: object,
reference_wav_len: object,
finalize: bool = None) -> torch.Tensor:
"""Forward pass through the vocoder component.
Args:
prompt_speech_tokens: Prompt speech tokens tensor
prompt_speech_feat: Prompt speech feat tensor
prompt_spk_embedding: Prompt spk embedding tensor
target_speech_tokens: Target speech tokens tensor
Returns:
Generated waveform tensor
"""
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav_dit',
requested_output_names=[
"waveform",
],
inputs=inputs_tensor,
request_id=request_id,
parameters={"priority": index+1},
)
inference_response = await inference_request.async_exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output waveform
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
return waveform
def _extract_speech_feat(self, speech):
speech_feat = mel_spectrogram(
speech,
n_fft=1920,
num_mels=80,
sampling_rate=24000,
hop_size=480,
win_size=1920,
fmin=0,
fmax=8000).squeeze(
dim=0).transpose(
0,
1).to(
self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat
async def _process_request(self, request):
request_id = request.request_id()
# Extract input tensors
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
# Process reference audio through audio tokenizer
if wav is not None:
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
wav_tensor = wav.as_numpy()
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}")
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
speech_feat = self._extract_speech_feat(prompt_speech_resample)
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
# prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
# reference_text = self.default_spk_info["prompt_text"]
# prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
# prompt_speech_feat = None
# prompt_spk_embedding = None
else:
# using pre-cached reference text
assert False, "using pre-cached reference text is not supported"
reference_text = self.default_spk_info["prompt_text"]
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
prompt_speech_feat = None
prompt_spk_embedding = None
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode('utf-8')
print(f"target_text: {target_text}, time: {datetime.now()}")
if self.decoupled:
response_sender = request.get_response_sender()
semantic_token_ids_arr = []
token_offset, chunk_index = 0, 0
start_time = time.time()
this_token_hop_len = self.token_hop_len
print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}")
async for generated_ids in self.forward_llm_async(
target_text=target_text,
reference_text=reference_text,
prompt_speech_tokens=prompt_speech_tokens,
):
if not generated_ids:
break
semantic_token_ids_arr.append(generated_ids)
print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}")
while True:
pending_num = len(semantic_token_ids_arr) - token_offset
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}")
sub_tts_speech = await self.forward_token2wav(
chunk_index,
this_tts_speech_token, request_id, wav, wav_len, False
)
print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}")
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
token_offset += this_token_hop_len
self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
if self.dynamic_chunk_strategy == "exponential":
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
elif self.dynamic_chunk_strategy == "equal":
this_token_hop_len = self.token_hop_len
elif self.dynamic_chunk_strategy == "time_based":
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
cost_time = time.time() - start_time
duration = token_offset / self.token_frame_rate
if chunk_index > 0 and cost_time > 0:
avg_chunk_processing_time = cost_time / (chunk_index + 1)
if avg_chunk_processing_time > 0:
multiples = (duration - cost_time) / avg_chunk_processing_time
self.logger.log_info(f"multiples: {multiples}")
next_pending_num = len(semantic_token_ids_arr) - token_offset
if multiples > 4:
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
elif multiples > 2:
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
else:
this_token_hop_len = self.token_hop_len
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
chunk_index += 1
else:
break
this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info("send tritonserver_response_complete_final to end")
else:
raise NotImplementedError("Decoupled mode is not supported")
async def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated audio
"""
tasks = [
asyncio.create_task(self._process_request(request))
for request in requests
]
await asyncio.gather(*tasks)
return None
def finalize(self):
self.logger.log_info("Finalizing CosyVoice DIT model")
if hasattr(self, "http_client"):
asyncio.run(self.http_client.aclose())

View File

@@ -47,8 +47,6 @@ import requests
import asyncio
import httpx
from token2wav import CosyVoice2_Token2Wav
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
try:
torch.multiprocessing.set_start_method("spawn")
@@ -367,7 +365,12 @@ def main(args):
runner = None
else:
raise ValueError(f"Unsupported backend: {args.backend}")
if 'Step-Audio-2-mini' in args.token2wav_path:
from token2wav_dit import CosyVoice2_Token2Wav
else:
assert 'CosyVoice2-0.5B' in args.token2wav_path
from token2wav import CosyVoice2_Token2Wav
token2wav_model = CosyVoice2_Token2Wav(
model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
)
@@ -589,7 +592,6 @@ def main(args):
t2w_prompt_audios_list,
t2w_prompt_audios_sample_rate,
)
torch.cuda.synchronize()
token2wav_end_time = time.time()
total_token2wav_time += (token2wav_end_time - token2wav_start_time)

View File

@@ -1,28 +1,33 @@
#!/bin/bash
# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
export CUDA_VISIBLE_DEVICES=0
cosyvoice_path=/workspace/CosyVoice
# cosyvoice_path=/workspace/CosyVoice
cosyvoice_path=/workspace_yuekai/tts/CosyVoice
stepaudio2_path=/workspace_yuekai/tts/Step-Audio2
export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH
export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
stage=$1
stop_stage=$2
N_GPUS=2 # set the number of GPUs to use
huggingface_model_local_dir=./cosyvoice2_llm
model_scope_model_local_dir=./CosyVoice2-0.5B
step_audio_model_dir=./Step-Audio-2-mini
trt_dtype=bfloat16
trt_weights_dir=./trt_weights_${trt_dtype}
trt_engines_dir=./trt_engines_${trt_dtype}
model_repo=./model_repo_cosyvoice2_dit
use_spk2info_cache=False
bls_instance_num=4
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
echo "Cloning Step-Audio2-mini"
git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt $stepaudio2_path
echo "Cloning CosyVoice"
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
cd $cosyvoice_path
@@ -35,8 +40,13 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
# see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
# download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings
wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt
echo "Step-Audio2-mini"
huggingface-cli download --local-dir $step_audio_model_dir stepfun-ai/Step-Audio-2-mini
cd $stepaudio2_path/token2wav
wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O flow.decoder.estimator.fp32.dynamic_batch.onnx
wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx -O flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx
cd -
fi
@@ -60,40 +70,6 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
--engine_dir=$trt_engines_dir || exit 1
fi
# if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
# echo "Creating model repository"
# rm -rf $model_repo
# mkdir -p $model_repo
# cosyvoice2_dir="cosyvoice2_dit"
# token2wav_dir="token2wav_dit"
# cp -r ./model_repo/${cosyvoice2_dir} $model_repo
# cp -r ./model_repo/tensorrt_llm $model_repo
# cp -r ./model_repo/${token2wav_dir} $model_repo
# #if [ $use_spk2info_cache == "False" ]; then
# cp -r ./model_repo/audio_tokenizer $model_repo
# cp -r ./model_repo/speaker_embedding $model_repo
# #fi
# ENGINE_PATH=$trt_engines_dir
# MAX_QUEUE_DELAY_MICROSECONDS=0
# MODEL_DIR=$model_scope_model_local_dir
# LLM_TOKENIZER_DIR=$huggingface_model_local_dir
# BLS_INSTANCE_NUM=1
# TRITON_MAX_BATCH_SIZE=16
# DECOUPLED_MODE=True # True for streaming, False for offline
# STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav
# python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
# python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
# python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
# #if [ $use_spk2info_cache == "False" ]; then
# python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
# python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
# #fi
# fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Creating model repository async mode"
rm -rf $model_repo
@@ -102,122 +78,75 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
token2wav_dir="token2wav_dit"
cp -r ./model_repo/${cosyvoice2_dir} $model_repo
cp -r ./model_repo/tensorrt_llm $model_repo
cp -r ./model_repo/${token2wav_dir} $model_repo
#if [ $use_spk2info_cache == "False" ]; then
cp -r ./model_repo/audio_tokenizer $model_repo
cp -r ./model_repo/speaker_embedding $model_repo
#fi
cp -r ./model_repo/audio_tokenizer $model_repo
cp -r ./model_repo/speaker_embedding $model_repo
ENGINE_PATH=$trt_engines_dir
MAX_QUEUE_DELAY_MICROSECONDS=0
MODEL_DIR=$model_scope_model_local_dir
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
BLS_INSTANCE_NUM=4
BLS_INSTANCE_NUM=$bls_instance_num
TRITON_MAX_BATCH_SIZE=1
DECOUPLED_MODE=True # True for streaming, False for offline
STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav
DECOUPLED_MODE=True
STEP_AUDIO_MODEL_DIR=$step_audio_model_dir/token2wav
python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
#if [ $use_spk2info_cache == "False" ]; then
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
#fi
rm -rf $model_repo/tensorrt_llm
# mv $model_repo/cosyvoice2_dit/1 $model_repo/cosyvoice2_dit/4
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
echo "Starting Triton server on $N_GPUS GPUs"
for i in $(seq 0 $(($N_GPUS - 1))); do
echo "Starting server on GPU $i"
http_port=$((19000 + $i))
grpc_port=$((18000 + $i))
metrics_port=$((17000 + $i))
CUDA_VISIBLE_DEVICES=$i tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
done
echo "Servers are running in the background. Press Ctrl+C to stop them and the script."
wait
fi
if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
echo "Starting Triton server on $N_GPUS GPUs"
N_GPUS=1
for i in $(seq 0 $(($N_GPUS - 1))); do
echo "Starting server on GPU $i"
http_port=$((19000 + $i))
grpc_port=$((18000 + $i))
metrics_port=$((17000 + $i))
CUDA_VISIBLE_DEVICES=0 tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
done
echo "Servers are running in the background. Press Ctrl+C to stop them and the script."
echo "Starting Token2wav Triton server and Cosyvoice2 llm using trtllm-serve"
tritonserver --model-repository $model_repo --http-port 18000 &
mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4 &
wait
# Test using curl
# curl http://localhost:8000/v1/chat/completions \
# -H "Content-Type: application/json" \
# -d '{
# "model": "trt_engines_bfloat16",
# "messages":[{"role": "user", "content": "Where is New York?"},
# {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}],
# "max_tokens": 512,
# "temperature": 0.8,
# "top_p": 0.95,
# "top_k": 50,
# "stop": ["<|eos1|>"],
# "repetition_penalty": 1.2,
# "stream": false
# }'
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Single request test http, only work for offline TTS mode"
python3 client_http.py \
--reference-audio ./assets/prompt_audio.wav \
--reference-text "吃燕窝就选燕之屋本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝营养更均衡本节目由豆本豆豆奶特约播出。" \
--target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
--model-name cosyvoice2
echo "Running benchmark client"
num_task=4
mode=streaming
BLS_INSTANCE_NUM=$bls_instance_num
python3 client_grpc.py \
--server-addr localhost \
--server-port 8001 \
--model-name cosyvoice2_dit \
--num-tasks $num_task \
--mode $mode \
--huggingface-dataset yuekai/seed_tts_cosy2 \
--log-dir ./log_single_gpu_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Running benchmark client grpc on $N_GPUS GPUs"
num_task=1
echo "stage 5: Offline TTS (Cosyvoice2 LLM + Step-Audio2-mini DiT Token2Wav) inference using a single python script"
mode=streaming
BLS_INSTANCE_NUM=4
for i in $(seq 0 $(($N_GPUS - 1))); do
grpc_port=$((18000 + $i))
echo "Running client for server on localhost:$grpc_port"
python3 client_grpc.py \
--server-addr localhost \
--server-port $grpc_port \
--model-name cosyvoice2_dit \
--num-tasks $num_task \
--mode $mode \
--huggingface-dataset yuekai/seed_tts_cosy2 \
--log-dir ./log_debug_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_gpu${i} &
done
wait
fi
if [ $stage -le 50 ] && [ $stop_stage -ge 50 ]; then
echo "Running benchmark client grpc on $N_GPUS GPUs"
num_task=4
N_GPUS=1
mode=streaming
BLS_INSTANCE_NUM=4
for i in $(seq 0 $(($N_GPUS - 1))); do
grpc_port=$((18000 + $i))
echo "Running client for server on localhost:$grpc_port"
python3 client_grpc.py \
--server-addr localhost \
--server-port $grpc_port \
--model-name cosyvoice2_dit \
--num-tasks $num_task \
--mode $mode \
--huggingface-dataset yuekai/seed_tts_cosy2 \
--log-dir ./log_single_card_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM} &
done
wait
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "stage 6: Offline inference benchmark"
n_gpus=1
datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
backend=trtllm-serve # hf, trtllm, vllm
backend=trtllm # hf, trtllm, vllm, trtllm-serve
batch_sizes=(16 8 4 2 1)
batch_sizes=(16 8 4 2)
batch_sizes=(16)
token2wav_batch_size=1
for batch_size in ${batch_sizes[@]}; do
for dataset in ${datasets[@]}; do
output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
@@ -225,7 +154,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
python3 offline_inference.py \
--output-dir $output_dir \
--llm-model-name-or-path $huggingface_model_local_dir \
--token2wav-path $model_scope_model_local_dir \
--token2wav-path $step_audio_model_dir/token2wav \
--backend $backend \
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
--engine-dir $trt_engines_dir \
@@ -234,34 +163,13 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
done
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
CUDA_VISIBLE_DEVICES=2 python3 streaming_inference.py --enable-trt --strategy exponential
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "Running Step-Audio2-mini DiT Token2Wav inference using a single python script"
export CUDA_VISIBLE_DEVICES=1
# Note: Using pre-computed cosyvoice2 tokens
python3 streaming_inference.py --enable-trt --strategy equal # equal, exponential
# Offline Token2wav inference
# python3 token2wav_dit.py --enable-trt
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
#! /usr/bin/env bash
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "trt_engines_bfloat16",
"messages":[{"role": "user", "content": "Where is New York?"},
{"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}],
"max_tokens": 512,
"temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"stop": ["<|eos1|>"],
"repetition_penalty": 1.2,
"stream": false
}'
fi

View File

@@ -54,7 +54,7 @@ if __name__ == "__main__":
token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
flow_pre_lookahead_len = 3
CHUNK_SIZE = 15
CHUNK_SIZE = 25
token_frame_rate = 25
OVERLAP_SIZE = 0
@@ -67,20 +67,12 @@ if __name__ == "__main__":
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch
id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0]
# if id != "unseen3_text5":
# continue
# else:
# a = torch.load("semantic_token_ids_arr_debug_871e2b90-42a7-4829-957c-b45e6a96fdb2.pt")
# generated_speech_tokens = a["semantic_token_ids_arr"]
# print(generated_speech_tokens)
assert prompt_audio_sample_rate == 16000
prompt_text = prompt_text_list[0]
prompt_speech_tokens = prompt_speech_tokens_list[0]
# generated_ids_iter = fake_generated_id_iter(generated_speech_tokens)
semantic_token_ids_arr, token_offset = [], 0
flow_prompt_speech_token_len = len(prompt_speech_tokens)
@@ -114,14 +106,16 @@ if __name__ == "__main__":
audios = output_wavs
reconstructed_audio = np.concatenate(audios)
# Save reconstructed audio
sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
print(f"Saved {id}")
end_time = time.time()
if _ == 0:
token2wav_model.speaker_cache = {}
print(f"Warmup time: {end_time - start_time} seconds")
print(f"Total forward count: {total_forward_count}")
print(f"Warmup time: {end_time - start_time} seconds")
print("clear speaker cache")
elif _ == 1:
print(f"Cost time without speaker cache: {end_time - start_time} seconds")
else:
print(f"Cost time with speaker cache: {end_time - start_time} seconds")
print(f"Total flow matching forward calls: {total_forward_count}")