add streaming dit

This commit is contained in:
yuekaiz
2025-09-24 15:18:01 +08:00
parent 444b7ff5df
commit 482464ea27
10 changed files with 850 additions and 269 deletions

View File

@@ -209,7 +209,8 @@ def get_args():
         choices=[
             "f5_tts",
             "spark_tts",
-            "cosyvoice2"],
+            "cosyvoice2",
+            "cosyvoice2_dit"],
         help="triton model_repo module name to request",
     )
@@ -260,8 +261,8 @@ def get_args():
     parser.add_argument(
         "--use-spk2info-cache",
-        type=bool,
-        default=False,
+        type=str,
+        default="False",
         help="Use spk2info cache for reference audio.",
     )
@@ -490,6 +491,7 @@ async def send_streaming(
         padding_duration=padding_duration,
         use_spk2info_cache=use_spk2info_cache
     )
+
     request_id = str(uuid.uuid4())
     user_data = UserData()
@@ -670,11 +672,15 @@ async def main():
         trust_remote_code=True,
     )
     manifest_item_list = []
+    tmp_audio_path="./asset_zero_shot_prompt.wav"
+    tmp_audio_text="希望你以后能够做的比我还好呦。"
    for i in range(len(dataset)):
        manifest_item_list.append(
            {
                "audio_filepath": dataset[i]["prompt_audio"],
                "reference_text": dataset[i]["prompt_text"],
+                # "audio_filepath": tmp_audio_path,
+                # "reference_text": tmp_audio_text,
                "target_audio_path": dataset[i]["id"],
                "target_text": dataset[i]["target_text"],
            }
@@ -686,7 +692,7 @@ async def main():
     manifest_item_list = split_data(manifest_item_list, num_tasks)
     os.makedirs(args.log_dir, exist_ok=True)
+    args.use_spk2info_cache = args.use_spk2info_cache == "True" or args.use_spk2info_cache == "true"
     tasks = []
     start_time = time.time()
     for i in range(num_tasks):
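The `--use-spk2info-cache` flag is switched from `type=bool` to `type=str` because argparse applies `bool()` to the raw string and `bool("False")` is truthy; the client therefore compares the string explicitly in `main()`. A minimal, self-contained sketch of the pitfall and the conversion (illustrative only):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--use-spk2info-cache", type=str, default="False")
    args = parser.parse_args(["--use-spk2info-cache", "False"])

    # bool("False") would be True, so the value is compared as a string instead,
    # mirroring the conversion done in main():
    args.use_spk2info_cache = args.use_spk2info_cache in ("True", "true")
    print(args.use_spk2info_cache)  # False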

View File

@@ -227,12 +227,11 @@ class TritonPythonModel:
     def forward_token2wav(
         self,
+        index: int,
         target_speech_tokens: torch.Tensor,
         request_id: str,
-        prompt_speech_tokens: torch.Tensor = None,
-        prompt_speech_feat: torch.Tensor = None,
-        prompt_spk_embedding: torch.Tensor = None,
-        token_offset: int = None,
+        reference_wav: object,
+        reference_wav_len: object,
         finalize: bool = None) -> torch.Tensor:
         """Forward pass through the vocoder component.
@@ -246,29 +245,16 @@
             Generated waveform tensor
         """
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
-        inputs_tensor = [target_speech_tokens_tensor]
-        if token_offset is not None:
-            assert finalize is not None
-            token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
-            finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
-            inputs_tensor.append(token_offset_tensor)
-            inputs_tensor.append(finalize_tensor)
-        if prompt_spk_embedding is not None:
-            assert prompt_speech_feat is not None
-            prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
-            prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
-            prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
-            inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
+        finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
+        inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]

         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
-            model_name='token2wav',
+            model_name='token2wav_dit',
             requested_output_names=['waveform'],
             inputs=inputs_tensor,
             request_id=request_id,
+            parameters={"priority": index+1},
         )

         inference_response = inference_request.exec()
@@ -346,8 +332,15 @@
             reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
             reference_text = reference_text[0][0].decode('utf-8')

-            prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+            # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+            # reference_text = self.default_spk_info["prompt_text"]
+            # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
+            # prompt_speech_feat = None
+            # prompt_spk_embedding = None
         else:
+            assert False, "wav is None"
             # using pre-cached reference text
             reference_text = self.default_spk_info["prompt_text"]
             prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
@@ -391,12 +384,12 @@
                         break
                     if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
-                        this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
+                        this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                         this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
                         sub_tts_speech = self.forward_token2wav(
-                            this_tts_speech_token, request_id, prompt_speech_tokens,
-                            prompt_speech_feat, prompt_spk_embedding, token_offset, False
+                            chunk_index,
+                            this_tts_speech_token, request_id, wav, wav_len, False
                         )
                         audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
@@ -429,8 +422,8 @@
                     else:
                         time.sleep(0.02)

-                this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
-                sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
+                this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
+                sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 response_sender.send(inference_response)
@@ -439,17 +432,7 @@
                 response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                 self.logger.log_info("send tritonserver_response_complete_final to end")
             else:
-                generated_ids = next(generated_ids_iter)
-                generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
-                if generated_ids is None or len(generated_ids) == 0:
-                    raise pb_utils.TritonModelException("Generated IDs is None or empty")
-                audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
-                # Prepare response
-                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
-                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
-                responses.append(inference_response)
+                raise NotImplementedError("Non-decoupled (offline) mode is not supported")

         if not self.decoupled:
             return responses
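For reference, a minimal sketch of the chunking these hunks modify: each token2wav_dit request now receives only the window starting at token_offset (plus lookahead) instead of the whole prefix. Constants mirror the defaults in the new module; a fixed hop is used here for brevity, while the server grows the hop per chunk.

    token_hop_len = 15          # tokens synthesized per chunk (self.token_hop_len)
    flow_pre_lookahead_len = 3  # lookahead tokens passed to the flow model

    semantic_token_ids_arr = list(range(40))  # stand-in for the streamed LLM tokens
    token_offset, chunks = 0, []
    while len(semantic_token_ids_arr) - token_offset >= token_hop_len + flow_pre_lookahead_len:
        # send only the un-synthesized window, not the full prefix
        chunks.append(semantic_token_ids_arr[token_offset:token_offset + token_hop_len + flow_pre_lookahead_len])
        token_offset += token_hop_len
    chunks.append(semantic_token_ids_arr[token_offset:])  # final chunk, finalize=True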

View File

@@ -0,0 +1,438 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import math
import os
import re
import time
from typing import Dict, List, Tuple, Optional, Union
import asyncio
import httpx
import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
import torchaudio
from matcha.utils.audio import mel_spectrogram
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
def parse_speech_token_string(response_text: str) -> List[int]:
"""
Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
"""
speech_tokens = response_text.strip().split('><')
if len(speech_tokens) > 1:
# Add back the missing '<' and '>' for proper parsing
speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
speech_ids = []
for token_str in speech_tokens:
match = re.match(r'<\|s_(\d+)\|>', token_str)
if match:
speech_ids.append(int(match.group(1)))
return speech_ids
class TritonPythonModel:
"""Triton Python model for Spark TTS.
This model orchestrates the end-to-end TTS pipeline by coordinating
between audio tokenizer, LLM, and vocoder components.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
self.logger = pb_utils.Logger
# Parse model parameters
self.model_config = json.loads(args['model_config'])
parameters = self.model_config['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.logger.log_info(f"model_params:{model_params}")
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
# Initialize tokenizer
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
self.prompt_template = "<|sos|>{input_text}<|task_id|>"
self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
self.device = torch.device("cuda")
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
self.token_frame_rate = 25
self.flow_pre_lookahead_len = 3
self.token_hop_len = 15
spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
if not os.path.exists(spk_info_path):
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
# self.default_spk_info = spk_info["001"]
def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
"""Converts a tensor or list of speech token IDs to a string representation."""
if isinstance(speech_tokens, torch.Tensor):
# Ensure tensor is on CPU and flattened
speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
speech_id_str = ""
for token_id in speech_tokens:
# Convert token ID back to the speech number N
token_num = token_id - ORIGINAL_VOCAB_SIZE
speech_id_str += f"<|s_{token_num}|>"
return speech_id_str
async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
"""
Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
"""
full_text = f"{reference_text}{target_text}"
prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
chat = [
{"role": "user", "content": full_text},
{"role": "assistant", "content": prompt_speech_tokens_str}
]
print(chat)
payload = {
"model": "trt_engines_bfloat16",
"messages": chat,
"max_tokens": 750,
"temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"repetition_penalty": 1.1,
"stop": ["<|eos1|>", "<|eos|>"],
"stream": True,
}
api_base = "http://localhost:8000/v1/chat/completions"
buffer = ""
async with httpx.AsyncClient() as client:
async with client.stream("POST", api_base, json=payload, timeout=None) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
line_data = line[len("data: "):].strip()
if line_data == "[DONE]":
break
try:
json_data = json.loads(line_data)
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
if content:
buffer += content
while True:
match = re.search(r"<\|s_(\d+)\|>", buffer)
if not match:
break
token_num = int(match.group(1))
final_id = token_num + ORIGINAL_VOCAB_SIZE
yield final_id
buffer = buffer[match.end():]
except json.JSONDecodeError:
self.logger.log_info(f"Skipping non-JSON line: {line_data}")
continue
# Process any remaining complete tokens in the buffer after the stream ends
while True:
match = re.search(r"<\|s_(\d+)\|>", buffer)
if not match:
break
token_num = int(match.group(1))
final_id = token_num + ORIGINAL_VOCAB_SIZE
yield final_id
buffer = buffer[match.end():]
def forward_audio_tokenizer(self, wav, wav_len):
"""Forward pass through the audio tokenizer component.
Args:
wav: Input waveform tensor
wav_len: Waveform length tensor
Returns:
Prompt speech tokens tensor
"""
inference_request = pb_utils.InferenceRequest(
model_name='audio_tokenizer',
requested_output_names=['prompt_speech_tokens'],
inputs=[wav, wav_len]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
return prompt_speech_tokens
def forward_speaker_embedding(self, wav):
"""Forward pass through the speaker embedding component.
Args:
wav: Input waveform tensor
Returns:
Prompt speaker embedding tensor
"""
inference_request = pb_utils.InferenceRequest(
model_name='speaker_embedding',
requested_output_names=['prompt_spk_embedding'],
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
return prompt_spk_embedding
def forward_token2wav(
self,
index: int,
target_speech_tokens: torch.Tensor,
request_id: str,
reference_wav: object,
reference_wav_len: object,
finalize: bool = None) -> torch.Tensor:
"""Forward pass through the vocoder component.
Args:
prompt_speech_tokens: Prompt speech tokens tensor
prompt_speech_feat: Prompt speech feat tensor
prompt_spk_embedding: Prompt spk embedding tensor
target_speech_tokens: Target speech tokens tensor
Returns:
Generated waveform tensor
"""
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav_dit',
requested_output_names=['waveform'],
inputs=inputs_tensor,
request_id=request_id,
parameters={"priority": index+1},
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output waveform
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
return waveform
def _extract_speech_feat(self, speech):
speech_feat = mel_spectrogram(
speech,
n_fft=1920,
num_mels=80,
sampling_rate=24000,
hop_size=480,
win_size=1920,
fmin=0,
fmax=8000).squeeze(
dim=0).transpose(
0,
1).to(
self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat
async def _process_request(self, request):
request_id = request.request_id()
# Extract input tensors
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
# Process reference audio through audio tokenizer
if wav is not None:
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
wav_tensor = wav.as_numpy()
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
speech_feat = self._extract_speech_feat(prompt_speech_resample)
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
# prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
# reference_text = self.default_spk_info["prompt_text"]
# prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
# prompt_speech_feat = None
# prompt_spk_embedding = None
else:
# using pre-cached reference text
assert False, "using pre-cached reference text is not supported"
reference_text = self.default_spk_info["prompt_text"]
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
prompt_speech_feat = None
prompt_spk_embedding = None
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode('utf-8')
if self.decoupled:
response_sender = request.get_response_sender()
semantic_token_ids_arr = []
token_offset, chunk_index = 0, 0
start_time = time.time()
this_token_hop_len = self.token_hop_len
async for generated_ids in self.forward_llm_async(
target_text=target_text,
reference_text=reference_text,
prompt_speech_tokens=prompt_speech_tokens,
):
if not generated_ids:
break
semantic_token_ids_arr.append(generated_ids)
while True:
pending_num = len(semantic_token_ids_arr) - token_offset
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(
chunk_index,
this_tts_speech_token, request_id, wav, wav_len, False
)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
token_offset += this_token_hop_len
self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
if self.dynamic_chunk_strategy == "exponential":
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
elif self.dynamic_chunk_strategy == "time_based":
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
cost_time = time.time() - start_time
duration = token_offset / self.token_frame_rate
if chunk_index > 0 and cost_time > 0:
avg_chunk_processing_time = cost_time / (chunk_index + 1)
if avg_chunk_processing_time > 0:
multiples = (duration - cost_time) / avg_chunk_processing_time
self.logger.log_info(f"multiples: {multiples}")
next_pending_num = len(semantic_token_ids_arr) - token_offset
if multiples > 4:
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
elif multiples > 2:
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
else:
this_token_hop_len = self.token_hop_len
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
chunk_index += 1
else:
break
this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
## debug
## save semantic_token_ids_arr and reference_text, target_text to a single json file
# save into a torch .pt
# for i, item in enumerate(semantic_token_ids_arr):
# semantic_token_ids_arr[i] = item - ORIGINAL_VOCAB_SIZE
# import json
# data = {
# "semantic_token_ids_arr": semantic_token_ids_arr,
# "reference_text": reference_text,
# "target_text": target_text
# }
# with open(f"semantic_token_ids_arr_debug_{request_id}.pt", "wb") as f:
# torch.save(data, f)
# with open(f"semantic_token_ids_arr_debug_{request_id}.json", "w") as f:
# json.dump(data, f)
# ##
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info("send tritonserver_response_complete_final to end")
else:
raise NotImplementedError("Non-decoupled (offline) mode is not supported")
async def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated audio
"""
tasks = [
asyncio.create_task(self._process_request(request))
for request in requests
]
await asyncio.gather(*tasks)
return None
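A short usage sketch for parse_speech_token_string defined above (illustrative; assumes the module above is importable):

    # from model import parse_speech_token_string, ORIGINAL_VOCAB_SIZE  # module name assumed
    tokens = parse_speech_token_string("<|s_123|><|s_456|><|s_789|>")
    assert tokens == [123, 456, 789]
    # The BLS shifts IDs up by ORIGINAL_VOCAB_SIZE before passing them on;
    # token2wav_dit shifts them back down.
    shifted = [t + ORIGINAL_VOCAB_SIZE for t in tokens]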

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-name: "cosyvoice2"
+name: "cosyvoice2_dit"
 backend: "python"
 max_batch_size: ${triton_max_batch_size}
 dynamic_batching {

View File

@@ -42,6 +42,8 @@ from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vl
 from cosyvoice.utils.common import TrtContextWrapper
 from collections import defaultdict
 import numpy as np
+from .token2wav_dit import CosyVoice2_Token2Wav
+import hashlib

 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
@@ -49,117 +51,19 @@ logger = logging.getLogger(__name__)
 ORIGINAL_VOCAB_SIZE = 151663
 torch.set_num_threads(1)

-class CosyVoice2:
-
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
-        self.model_dir = model_dir
-        self.fp16 = fp16
-        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
-        if not os.path.exists(hyper_yaml_path):
-            raise ValueError('{} not found!'.format(hyper_yaml_path))
-        with open(hyper_yaml_path, 'r') as f:
-            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
-        self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
-        self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
-        if load_jit:
-            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
-        if load_trt:
-            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
-                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
-                                trt_concurrent,
-                                self.fp16)
-
-
-class CosyVoice2Model:
-
-    def __init__(self,
-                 flow: torch.nn.Module,
-                 hift: torch.nn.Module,
-                 fp16: bool = False,
-                 device: str = 'cuda'):
-        self.device = device
-        self.flow = flow
-        self.hift = hift
-        self.fp16 = fp16
-        if self.fp16 is True:
-            self.flow.half()
-        # streaming tts config
-        self.token_hop_len = 25
-        self.mel_cache_len = 8
-        self.source_cache_len = int(self.mel_cache_len * 480)
-        self.speech_window = np.hamming(2 * self.source_cache_len)
-        self.hift_cache_dict = defaultdict(lambda: None)
-
-    def load_jit(self, flow_encoder_model):
-        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
-        self.flow.encoder = flow_encoder
-
-    def load(self, flow_model, hift_model):
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
-        self.flow.to(self.device).eval()
-        # in case hift_model is a hifigan model
-        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
-        self.hift.load_state_dict(hift_state_dict, strict=True)
-        self.hift.to(self.device).eval()
-
-    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
-        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
-        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
-            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
-        del self.flow.decoder.estimator
-        import tensorrt as trt
-        with open(flow_decoder_estimator_model, 'rb') as f:
-            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
-        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
-        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
-
-    def get_trt_kwargs(self):
-        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
-        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
-        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
-        input_names = ["x", "mask", "mu", "cond"]
-        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
-
-    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
-        with torch.cuda.amp.autocast(self.fp16):
-            tts_mel, _ = self.flow.inference(token=token.to(self.device),
-                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                             prompt_token=prompt_token.to(self.device),
-                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                             prompt_feat=prompt_feat.to(self.device),
-                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                             embedding=embedding.to(self.device),
-                                             streaming=stream,
-                                             finalize=finalize)
-            tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
-            # append hift cache
-            if self.hift_cache_dict[uuid] is not None:
-                hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
-                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
-            else:
-                hift_cache_source = torch.zeros(1, 1, 0)
-            # keep overlap mel and hift cache
-            if finalize is False:
-                tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
-                if self.hift_cache_dict[uuid] is not None:
-                    tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
-                self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
-                                              'source': tts_source[:, :, -self.source_cache_len:],
-                                              'speech': tts_speech[:, -self.source_cache_len:]}
-                tts_speech = tts_speech[:, :-self.source_cache_len]
-            else:
-                if speed != 1.0:
-                    assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
-                    tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
-                tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
-                if self.hift_cache_dict[uuid] is not None:
-                    tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
-        return tts_speech
+def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
+    """
+    Generates a unique ID for a torch.Tensor.
+    Tensors with the same elements and properties will have the same ID.
+    """
+    # Convert tensor to a byte string
+    tensor_bytes = tensor.numpy().tobytes()
+    # Create a SHA-256 hash of the byte string
+    hasher = hashlib.sha256()
+    hasher.update(tensor_bytes)
+    return hasher.hexdigest()


 class TritonPythonModel:
     """Triton Python model for vocoder.
@@ -183,16 +87,10 @@ class TritonPythonModel:
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         logger.info(f"Initializing vocoder from {model_dir} on {self.device}")

-        self.token2wav_model = CosyVoice2(
-            model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
+        # FIXME: device id settings
+        self.token2wav_model = CosyVoice2_Token2Wav(
+            model_dir, enable_trt=True, streaming=True
         )
-        spk_info_path = os.path.join(model_dir, "spk2info.pt")
-        if not os.path.exists(spk_info_path):
-            raise ValueError(f"spk2info.pt not found in {model_dir}")
-        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
-        self.default_spk_info = spk_info["001"]

         logger.info("Token2Wav initialized successfully")

     def execute(self, requests):
@@ -208,66 +106,31 @@ class TritonPythonModel:
        # Process each request in batch
        for request in requests:
            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
-            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
-
-            prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
-            if prompt_speech_tokens_tensor is not None:
-                prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
-                prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
-                prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
-                prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
-                prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
-                prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
-                prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
-            else:
-                prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
-                prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
-                prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)
+            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device)

            # shift the speech tokens according to the original vocab size
            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
+            target_speech_tokens = target_speech_tokens.squeeze().tolist()

            # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
-            token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
-            if token_offset is not None:
-                token_offset = token_offset.as_numpy().item()
-                finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
-                if not finalize:
-                    stream = True
-                else:
-                    stream = False
-                request_id = request.request_id()
-                audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
-                                                                 prompt_token=prompt_speech_tokens,
-                                                                 prompt_feat=prompt_speech_feat,
-                                                                 embedding=prompt_spk_embedding,
-                                                                 token_offset=token_offset,
-                                                                 uuid=request_id,
-                                                                 stream=stream,
-                                                                 finalize=finalize)
-                if finalize:
-                    self.token2wav_model.model.hift_cache_dict.pop(request_id)
-            else:
-                tts_mel, _ = self.token2wav_model.model.flow.inference(
-                    token=target_speech_tokens,
-                    token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
-                        self.device
-                    ),
-                    prompt_token=prompt_speech_tokens,
-                    prompt_token_len=torch.tensor(
-                        [prompt_speech_tokens.shape[1]], dtype=torch.int32
-                    ).to(self.device),
-                    prompt_feat=prompt_speech_feat,
-                    prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
-                    embedding=prompt_spk_embedding,
-                    streaming=False,
-                    finalize=True,
-                )
-                audio_hat, _ = self.token2wav_model.model.hift.inference(
-                    speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
-                )
+            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
+            request_id = request.request_id()
+
+            wav_array = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav").as_numpy()
+            wav_len = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav_len").as_numpy().item()
+
+            wav_array = torch.from_numpy(wav_array)
+            # Prepare inputs
+            wav = wav_array[:, :wav_len].squeeze(0)
+            spk_id = get_spk_id_from_prompt_audio(wav)
+            # wav = wav.to(self.device)
+            audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)

            generated_wave = audio_hat.squeeze(0).cpu().numpy()
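A minimal sketch of the speaker-cache key used above: prompts with byte-identical audio hash to the same ID, so the per-speaker flow caches can be reused across requests (illustrative, standalone):

    import hashlib
    import torch

    def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
        # SHA-256 over the raw sample bytes; identical prompts -> identical speaker id
        return hashlib.sha256(tensor.numpy().tobytes()).hexdigest()

    wav_a = torch.zeros(16000)   # 1 s of 16 kHz audio as a stand-in prompt
    wav_b = wav_a.clone()
    assert get_spk_id_from_prompt_audio(wav_a) == get_spk_id_from_prompt_audio(wav_b)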

View File

@@ -362,17 +362,17 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             spk_emb_for_flow.to(self.device),
             n_timesteps=10
         )
+        new_cache = {k: v.clone() for k, v in cache.items()}
         # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
-        cache['estimator_att_cache'] = cache['estimator_att_cache'].clone()
-        cache['estimator_cnn_cache'] = cache['estimator_cnn_cache'].clone()
-        return cache
+        return new_cache

     @torch.inference_mode()
     def forward_streaming(
         self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
     ):
         if speaker_id not in self.speaker_cache:
+        # if 1:
             assert prompt_audio is not None, "prompt_audio is required for new speaker"
             assert prompt_audio_sample_rate == 16000
@@ -382,10 +382,21 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous()
             prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len]

-            cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
             prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow}
+            # if speaker_id not in self.speaker_cache:
+            # if 1:
+            cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
             self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
+            print(f"speaker_id {speaker_id} added to cache")
+
+        # get a clone of cache dict ['estimator_att_cache'] and later check if it would be change
+        att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache'].clone()
+        cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache'].clone()
+        conformer_cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache'].clone()
+        conformer_att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache'].clone()

         if request_id not in self.streaming_flow_cache:
             self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
@@ -409,6 +420,33 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             n_timesteps=10,
         )

+        # get the original att_cache
+        original_att_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache']
+        original_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache']
+        original_conformer_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache']
+        original_conformer_att_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache']
+        if not torch.allclose(original_att_cache, att_cache_clone):
+            print("att_cache changed")
+            # print the last 10 elements of original_att_cache and att_cache_clone
+            print(original_att_cache[:, :, :, -10:])
+            print(att_cache_clone[:, :, :, -10:])
+            breakpoint()
+        if not torch.allclose(original_cnn_cache, cnn_cache_clone):
+            print("cnn_cache changed")
+            print(original_cnn_cache[..., -10:])
+            print(cnn_cache_clone[..., -10:])
+            breakpoint()
+        if not torch.allclose(original_conformer_cnn_cache, conformer_cnn_cache_clone):
+            print("conformer_cnn_cache changed")
+            print(original_conformer_cnn_cache[..., -10:])
+            print(conformer_cnn_cache_clone[..., -10:])
+            breakpoint()
+        if not torch.allclose(original_conformer_att_cache, conformer_att_cache_clone):
+            print("conformer_att_cache changed")
+            print(original_conformer_att_cache[..., -10:])
+            print(conformer_att_cache_clone[..., -10:])
+            breakpoint()
+
         self.streaming_flow_cache[request_id] = new_streaming_flow_cache
@@ -420,10 +458,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
-            hift_cache_mel = self.hift_cache_dict[request_id]['mel']
-            hift_cache_source = self.hift_cache_dict[request_id]['source']
-            hift_cache_speech = self.hift_cache_dict[request_id]['speech']
-            mel = torch.concat([hift_cache_mel, chunk_mel], dim=2)
+            hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
+            hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
+            hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()
+            mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone()

         speech, source = self.hift(mel, hift_cache_source)
@@ -444,7 +482,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             assert request_id in self.streaming_flow_cache
             self.streaming_flow_cache.pop(request_id)
             self.hift_cache_dict.pop(request_id)
+            # breakpoint()
         return speech


 def collate_fn(batch):
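A simplified sketch of the cache discipline behind the .clone() calls above: the shared per-speaker cache is copied into a per-request streaming cache so that chunked flow inference can update its state without mutating the entry other requests reuse (assumed simplification; tensor shapes are arbitrary):

    import torch

    speaker_cache = {
        "spk0": {"estimator_att_cache": torch.zeros(2, 4), "estimator_cnn_cache": torch.zeros(2, 4)}
    }
    streaming_flow_cache = {}

    def start_request(request_id: str, speaker_id: str) -> None:
        # clone, never alias: per-request updates must not touch the speaker entry
        streaming_flow_cache[request_id] = {k: v.clone() for k, v in speaker_cache[speaker_id].items()}

    start_request("req-1", "spk0")
    streaming_flow_cache["req-1"]["estimator_att_cache"] += 1.0
    assert speaker_cache["spk0"]["estimator_att_cache"].sum().item() == 0.0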

View File

@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-name: "token2wav"
+name: "token2wav_dit"
 backend: "python"
 max_batch_size: ${triton_max_batch_size}
 dynamic_batching {
   max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+  priority_levels: 10
+  default_priority_level: 10
 }
 parameters [
   {
@@ -32,29 +34,14 @@ input [
     dims: [-1]
   },
   {
-    name: "prompt_speech_tokens"
-    data_type: TYPE_INT32
+    name: "reference_wav"
+    data_type: TYPE_FP32
     dims: [-1]
-    optional: true
   },
   {
-    name: "prompt_speech_feat"
-    data_type: TYPE_FP16
-    dims: [-1, 80]
-    optional: true
-  },
-  {
-    name: "prompt_spk_embedding"
-    data_type: TYPE_FP16
-    dims: [-1]
-    optional: true
-  },
-  {
-    name: "token_offset"
+    name: "reference_wav_len"
     data_type: TYPE_INT32
-    dims: [ 1 ]
-    reshape: { shape: [ ] }
-    optional: true
+    dims: [1]
   },
   {
     name: "finalize"

View File

@@ -43,6 +43,9 @@ import soundfile as sf
 import s3tokenizer
 from functools import partial
 import time
+import requests
+import asyncio
+import httpx

 from token2wav import CosyVoice2_Token2Wav
@@ -53,6 +56,32 @@
     pass


+async def send_request_async(client, url, payload):
+    response = await client.post(url, json=payload, timeout=None)
+    response.raise_for_status()
+    response_json = response.json()
+    return response_json['choices'][0]['message']['content']
+
+
+async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
+    async with httpx.AsyncClient() as client:
+        tasks = []
+        for chat in chats:
+            payload = {
+                "model": model_name,
+                "messages": chat,
+                "max_tokens": 2048,
+                "temperature": temperature,
+                "top_p": top_p,
+                "top_k": top_k,
+                "repetition_penalty": 1.1,
+                "stop": ["<|eos1|>", "<|eos|>"],
+                "stream": False,
+            }
+            tasks.append(send_request_async(client, api_base, payload))
+        return await asyncio.gather(*tasks)
+
+
 def extract_speech_ids(speech_tokens_str):
     """Extract speech IDs from token strings like <|s_23456|>"""
     speech_ids = []
@@ -149,7 +178,7 @@ def get_args():
         "--backend",
         type=str,
         default="hf",
-        choices=["hf", "trtllm", "vllm"],
+        choices=["hf", "trtllm", "vllm", "trtllm-serve"],
         help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
     )
     parser.add_argument(
@@ -164,6 +193,18 @@ def get_args():
         default=0.6,
         help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
     )
+    parser.add_argument(
+        "--openai-api-base",
+        type=str,
+        default="http://localhost:8000/v1/chat/completions",
+        help="OpenAI API base URL (for trtllm-serve backend)",
+    )
+    parser.add_argument(
+        "--openai-model-name",
+        type=str,
+        default="trt_engines_bfloat16",
+        help="Model name to use with OpenAI API (for trtllm-serve backend)",
+    )
     args = parser.parse_args()
     return args
@@ -180,6 +221,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
     input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
     prompt_text_after_apply_template_list = []
     mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
+    chat_list = []
     for _, item in enumerate(batch):
         audio_processing_start_time = time.time()
         prompt_text, target_text = (
@@ -237,6 +279,7 @@
             {"role": "user", "content": full_text_list[i]},
             {"role": "assistant", "content": prompt_audio_cosy2_id_str}
         ]
+        chat_list.append(chat)

         assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
@@ -265,6 +308,7 @@
         "audio_processing_time": total_audio_processing_time,
         "speech_tokenization_time": total_speech_tokenization_time,
         "text_tokenization_time": total_text_tokenization_time,
+        "chat_list": chat_list
     }
@@ -318,6 +362,9 @@ def main(args):
     elif args.backend == "vllm":
         model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
         runner = None
+    elif args.backend == "trtllm-serve":
+        model = None
+        runner = None
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
@@ -452,6 +499,35 @@
                 print(outputs)
                 for j, output in enumerate(outputs):
                     outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
+            elif args.backend == "trtllm-serve":
+                if args.batch_size > 1:
+                    outputs = asyncio.run(send_batch_requests_async(
+                        args.openai_api_base,
+                        args.openai_model_name,
+                        batch["chat_list"],
+                        args.temperature,
+                        args.top_p,
+                        args.top_k,
+                    ))
+                else:
+                    outputs = []
+                    for i, chat in enumerate(batch["chat_list"]):
+                        payload = {
+                            "model": args.openai_model_name,
+                            "messages": chat,
+                            "max_tokens": 2048,
+                            "temperature": args.temperature,
+                            "top_p": args.top_p,
+                            "top_k": args.top_k,
+                            "repetition_penalty": 1.1,
+                            "stop": ["<|eos1|>", "<|eos|>"],
+                            "stream": False,
+                        }
+                        response = requests.post(args.openai_api_base, json=payload)
+                        response.raise_for_status()
+                        response_json = response.json()
+                        generated_content = response_json['choices'][0]['message']['content']
+                        outputs.append(generated_content)

             llm_end_time = time.time()
             total_llm_time += (llm_end_time - llm_start_time)
@@ -459,10 +535,21 @@
             items_for_token_2wav = []
             for i in range(len(batch["ids"])):
                 llm_post_processing_start_time = time.time()
-                input_length = len(batch["input_ids"][i])
-                generated_ids = outputs[i][input_length:]
-                speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-                speech_ids = extract_speech_ids(speech_tokens_str)
+                if args.backend == "trtllm-serve":
+                    speech_tokens_str = outputs[i].strip().split('><')
+                    if len(speech_tokens_str) > 1:
+                        speech_tokens_str = [
+                            t if t.startswith('<') else '<' + t for t in speech_tokens_str
+                        ]
+                        speech_tokens_str = [
+                            t if t.endswith('>') else t + '>' for t in speech_tokens_str
+                        ]
+                    speech_ids = extract_speech_ids(speech_tokens_str)
+                else:
+                    input_length = len(batch["input_ids"][i])
+                    generated_ids = outputs[i][input_length:]
+                    speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+                    speech_ids = extract_speech_ids(speech_tokens_str)
                 print(i, speech_ids)
                 if len(speech_ids) == 0:
                     print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
@@ -558,6 +645,8 @@
         from tensorrt_llm.runtime import ModelRunnerCpp
     elif args.backend == "hf":
         from transformers import AutoModelForCausalLM
+    elif args.backend == "trtllm-serve":
+        pass
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
     main(args)
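Usage sketch for send_batch_requests_async defined above (requires a running trtllm-serve endpoint such as the one launched in stage 8 of run.sh; values are illustrative):

    import asyncio

    chats = [[
        {"role": "user", "content": "reference text + target text"},
        {"role": "assistant", "content": "<|s_1|><|s_2|>"},
    ]]
    outputs = asyncio.run(send_batch_requests_async(
        "http://localhost:8000/v1/chat/completions",  # api_base
        "trt_engines_bfloat16",                       # model_name
        chats,
        0.8,   # temperature
        0.95,  # top_p
        50,    # top_k
    ))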

View File

@@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang) # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=1
cosyvoice_path=/workspace/CosyVoice cosyvoice_path=/workspace/CosyVoice
cosyvoice_path=/workspace_yuekai/tts/CosyVoice cosyvoice_path=/workspace_yuekai/tts/CosyVoice
stepaudio2_path=/workspace_yuekai/tts/Step-Audio2 stepaudio2_path=/workspace_yuekai/tts/Step-Audio2
@@ -16,7 +16,7 @@ trt_dtype=bfloat16
trt_weights_dir=./trt_weights_${trt_dtype} trt_weights_dir=./trt_weights_${trt_dtype}
trt_engines_dir=./trt_engines_${trt_dtype} trt_engines_dir=./trt_engines_${trt_dtype}
model_repo=./model_repo_cosyvoice2 model_repo=./model_repo_cosyvoice2_dit
use_spk2info_cache=False use_spk2info_cache=False
@@ -58,40 +58,78 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
--engine_dir=$trt_engines_dir || exit 1 --engine_dir=$trt_engines_dir || exit 1
fi fi
# if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
# echo "Creating model repository"
# rm -rf $model_repo
# mkdir -p $model_repo
# cosyvoice2_dir="cosyvoice2_dit"
# token2wav_dir="token2wav_dit"
# cp -r ./model_repo/${cosyvoice2_dir} $model_repo
# cp -r ./model_repo/tensorrt_llm $model_repo
# cp -r ./model_repo/${token2wav_dir} $model_repo
# #if [ $use_spk2info_cache == "False" ]; then
# cp -r ./model_repo/audio_tokenizer $model_repo
# cp -r ./model_repo/speaker_embedding $model_repo
# #fi
# ENGINE_PATH=$trt_engines_dir
# MAX_QUEUE_DELAY_MICROSECONDS=0
# MODEL_DIR=$model_scope_model_local_dir
# LLM_TOKENIZER_DIR=$huggingface_model_local_dir
# BLS_INSTANCE_NUM=1
# TRITON_MAX_BATCH_SIZE=16
# DECOUPLED_MODE=True # True for streaming, False for offline
# STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav
# python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
# python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
# python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
# #if [ $use_spk2info_cache == "False" ]; then
# python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
# python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
# #fi
# fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Creating model repository" echo "Creating model repository async mode"
rm -rf $model_repo rm -rf $model_repo
mkdir -p $model_repo mkdir -p $model_repo
cosyvoice2_dir="cosyvoice2" cosyvoice2_dir="cosyvoice2_dit"
token2wav_dir="token2wav_dit"
cp -r ./model_repo/${cosyvoice2_dir} $model_repo cp -r ./model_repo/${cosyvoice2_dir} $model_repo
cp -r ./model_repo/tensorrt_llm $model_repo cp -r ./model_repo/tensorrt_llm $model_repo
cp -r ./model_repo/token2wav $model_repo cp -r ./model_repo/${token2wav_dir} $model_repo
if [ $use_spk2info_cache == "False" ]; then #if [ $use_spk2info_cache == "False" ]; then
cp -r ./model_repo/audio_tokenizer $model_repo cp -r ./model_repo/audio_tokenizer $model_repo
cp -r ./model_repo/speaker_embedding $model_repo cp -r ./model_repo/speaker_embedding $model_repo
fi #fi
ENGINE_PATH=$trt_engines_dir
MAX_QUEUE_DELAY_MICROSECONDS=0
MODEL_DIR=$model_scope_model_local_dir
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
BLS_INSTANCE_NUM=4
TRITON_MAX_BATCH_SIZE=32
DECOUPLED_MODE=True # True for streaming, False for offline
STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav
python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
#if [ $use_spk2info_cache == "False" ]; then
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
#fi
rm -rf $model_repo/tensorrt_llm
# mv $model_repo/cosyvoice2_dit/1 $model_repo/cosyvoice2_dit/4
fi
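For reference, scripts/fill_template.py only substitutes the key:value pairs passed above into the placeholders of each config.pbtxt. A minimal Python sketch of that kind of substitution (assuming string.Template-style ${key} placeholders as in the tensorrtllm_backend template configs; this is not the actual script):

# Minimal stand-in for scripts/fill_template.py (assumption: config.pbtxt uses
# ${key} placeholders and values never contain commas, as in the commands above).
import sys
from string import Template

def fill_template(pbtxt_path: str, substitutions: str) -> None:
    pairs = dict(item.split(":", 1) for item in substitutions.split(","))
    with open(pbtxt_path) as f:
        template = Template(f.read())
    with open(pbtxt_path, "w") as f:
        # safe_substitute leaves unknown placeholders untouched instead of raising
        f.write(template.safe_substitute(**pairs))

if __name__ == "__main__":
    fill_template(sys.argv[1], sys.argv[2])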
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
echo "Starting Triton server"
tritonserver --model-repository $model_repo --http-port 18000
fi
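Before launching the benchmark client it can help to wait until the server reports ready; a small sketch using the standard tritonclient HTTP API (HTTP port 18000 as configured above, model name cosyvoice2_dit assumed):

# Poll Triton until both the server and the cosyvoice2_dit model are ready.
import time
import tritonclient.http as httpclient

def wait_until_ready(url: str = "localhost:18000", model: str = "cosyvoice2_dit", timeout_s: float = 120.0) -> None:
    client = httpclient.InferenceServerClient(url=url)
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            if client.is_server_ready() and client.is_model_ready(model):
                print("server and model are ready")
                return
        except Exception:
            pass  # server may not be accepting connections yet
        time.sleep(1.0)
    raise TimeoutError(f"{model} not ready after {timeout_s}s")

if __name__ == "__main__":
    wait_until_ready()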
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
@@ -112,26 +150,26 @@
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
python3 client_grpc.py \
--server-addr localhost \
--model-name cosyvoice2_dit \
--num-tasks $num_task \
--mode $mode \
--use-spk2info-cache $use_spk2info_cache \
--huggingface-dataset yuekai/seed_tts_cosy2 \
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_no_att_cnn_cache_new
fi
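For the streaming mode, the numbers of interest are the first-chunk latency and the real-time factor rather than end-to-end latency alone; given per-chunk arrival timestamps they reduce to simple arithmetic (a hypothetical helper, not part of client_grpc.py):

# Hypothetical helper: summarize one streaming request from chunk arrival timestamps.
# first-chunk latency = time from request start to first audio chunk;
# RTF = total wall-clock time / seconds of audio produced.
from typing import List

def streaming_metrics(request_start: float, chunk_arrival_times: List[float],
                      total_audio_seconds: float) -> dict:
    return {
        "first_chunk_latency_s": chunk_arrival_times[0] - request_start,
        "rtf": (chunk_arrival_times[-1] - request_start) / total_audio_seconds,
    }

# example: first chunk after 180 ms, 2.4 s of audio finished in 1.2 s -> RTF 0.5
print(streaming_metrics(0.0, [0.18, 0.55, 0.9, 1.2], 2.4))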
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "stage 6: Offline inference benchmark"
n_gpus=1
datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
backend=trtllm-serve # hf, trtllm, vllm, trtllm-serve
batch_sizes=(16 8 4 2 1)
batch_sizes=(16 8 4 2)
token2wav_batch_size=1
for batch_size in ${batch_sizes[@]}; do
for dataset in ${datasets[@]}; do
output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
CUDA_VISIBLE_DEVICES=1 \
python3 offline_inference.py \
--output-dir $output_dir \
--llm-model-name-or-path $huggingface_model_local_dir \
@@ -147,7 +185,31 @@ fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
python3 streaming_inference.py
fi
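streaming_inference.py (added later in this commit) drives the token2wav model in fixed-size token chunks; assuming CosyVoice2's 25 Hz speech-token rate, the chunk size chosen there maps to roughly one second of audio per step, plus a few lookahead tokens that must be buffered before the first chunk can be emitted:

# Back-of-envelope chunk timing for the streaming token2wav path.
# Assumption: 25 speech tokens per second (CosyVoice2 tokenizer rate).
TOKEN_RATE_HZ = 25
CHUNK_SIZE = 25        # tokens consumed per streaming step (see streaming_inference.py)
PRE_LOOKAHEAD = 3      # flow_pre_lookahead_len in streaming_inference.py

chunk_audio_seconds = CHUNK_SIZE / TOKEN_RATE_HZ             # ~1.0 s of audio per chunk
tokens_before_first_chunk = CHUNK_SIZE + PRE_LOOKAHEAD       # 28 tokens buffered before the first output
print(chunk_audio_seconds, tokens_before_first_chunk)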
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "trt_engines_bfloat16",
"messages":[{"role": "user", "content": "Where is New York?"},
{"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}],
"max_tokens": 512,
"temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"stop": ["<|eos1|>"],
"repetition_penalty": 1.2,
"stream": false
}'
fi
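The same request can be issued from Python; a sketch using requests against the OpenAI-compatible endpoint exposed by trtllm-serve above (payload copied from the curl call, default port 8000 as in its URL):

# Python equivalent of the curl call above against trtllm-serve's
# OpenAI-compatible /v1/chat/completions endpoint.
import requests

payload = {
    "model": "trt_engines_bfloat16",
    "messages": [
        {"role": "user", "content": "Where is New York?"},
        {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"},
    ],
    "max_tokens": 512,
    "temperature": 0.8,
    "top_p": 0.95,
    "top_k": 50,
    "stop": ["<|eos1|>"],
    "repetition_penalty": 1.2,
    "stream": False,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])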

@@ -0,0 +1,115 @@
import torch
import os
import argparse
from datasets import load_dataset
from torch.utils.data import DataLoader
import numpy as np
import torchaudio
import time
from token2wav_dit import CosyVoice2_Token2Wav
import soundfile as sf
def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
prompt_speech_tokens_list, prompt_text_list = [], []
for i, item in enumerate(batch):
generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
audio = torch.from_numpy(item['prompt_audio']['array']).float()
prompt_audios_list.append(audio)
prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
ids.append(item['id'])
prompt_speech_tokens_list.append(item['prompt_audio_cosy2_tokens'])
prompt_text_list.append(item['prompt_text'])
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--enable-trt", action="store_true")
parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument("--output-dir", type=str, default="generated_wavs")
parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2")
return parser.parse_args()
def fake_generated_id_iter(generated_speech_tokens_list):
for i in range(len(generated_speech_tokens_list)):
yield generated_speech_tokens_list[i]
if __name__ == "__main__":
args = get_args()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
dataset_name = args.dataset_name
dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
flow_pre_lookahead_len = 3  # flow model pre-lookahead length (the loop below reads the model's own attribute)
CHUNK_SIZE = 25  # speech tokens consumed per streaming step
OVERLAP_SIZE = 0  # tokens re-fed between consecutive chunks (disabled here)
warmup_times = 3  # passes over the dataset; the first pass serves as warmup
for run_idx in range(warmup_times):
start_time = time.time()
for batch in data_loader:
tts_speech_list = []
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch
id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0]
# if id != "unseen3_text5":
# continue
# else:
# a = torch.load("semantic_token_ids_arr_debug_871e2b90-42a7-4829-957c-b45e6a96fdb2.pt")
# generated_speech_tokens = a["semantic_token_ids_arr"]
# print(generated_speech_tokens)
assert prompt_audio_sample_rate == 16000
prompt_text = prompt_text_list[0]
prompt_speech_tokens = prompt_speech_tokens_list[0]
# generated_ids_iter = fake_generated_id_iter(generated_speech_tokens)
semantic_token_ids_arr, token_offset = [], 0
flow_prompt_speech_token_len = len(prompt_speech_tokens)
buffer = generated_speech_tokens
output_wavs = []
# Consume the token buffer chunk by chunk: each non-final call sees CHUNK_SIZE
# tokens plus the flow model's pre-lookahead tokens with the last-chunk flag set
# to False; the trailing remainder is flushed in a single call with the flag True.
while True:
if len(buffer) >= CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len:
wavs = token2wav_model.forward_streaming(buffer[:CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
# advance by CHUNK_SIZE (minus any overlap); the lookahead tokens stay in the buffer for the next chunk
buffer = buffer[CHUNK_SIZE - OVERLAP_SIZE:]
output_wavs.append(wavs)
else:
# last partial chunk: the remaining tokens are flushed with the last-chunk flag set to True
wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
output_wavs.append(wavs)
break
for i, wav in enumerate(output_wavs):
output_wavs[i] = wav.cpu().numpy().squeeze()
audios = output_wavs
reconstructed_audio = np.concatenate(audios)
# Save reconstructed audio
sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
print(f"Saved {id}")
end_time = time.time()
if run_idx == 0:
# reset the speaker cache after the first (warmup) pass so later passes rebuild it
token2wav_model.speaker_cache = {}
print(f"Warmup time: {end_time - start_time} seconds")