add streaming dit

This commit is contained in:
yuekaiz
2025-09-24 15:18:01 +08:00
parent 444b7ff5df
commit 482464ea27
10 changed files with 850 additions and 269 deletions

View File

@@ -227,12 +227,11 @@ class TritonPythonModel:
def forward_token2wav(
self,
index: int,
target_speech_tokens: torch.Tensor,
request_id: str,
prompt_speech_tokens: torch.Tensor = None,
prompt_speech_feat: torch.Tensor = None,
prompt_spk_embedding: torch.Tensor = None,
token_offset: int = None,
reference_wav: object,
reference_wav_len: object,
finalize: bool = None) -> torch.Tensor:
"""Forward pass through the vocoder component.
@@ -246,29 +245,16 @@ class TritonPythonModel:
Generated waveform tensor
"""
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
inputs_tensor = [target_speech_tokens_tensor]
if token_offset is not None:
assert finalize is not None
token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
inputs_tensor.append(token_offset_tensor)
inputs_tensor.append(finalize_tensor)
if prompt_spk_embedding is not None:
assert prompt_speech_feat is not None
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav',
model_name='token2wav_dit',
requested_output_names=['waveform'],
inputs=inputs_tensor,
request_id=request_id,
parameters={"priority": index+1},
)
inference_response = inference_request.exec()
@@ -346,8 +332,15 @@ class TritonPythonModel:
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
# prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
# reference_text = self.default_spk_info["prompt_text"]
# prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
# prompt_speech_feat = None
# prompt_spk_embedding = None
else:
assert False, "wav is None"
# using pre-cached reference text
reference_text = self.default_spk_info["prompt_text"]
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
@@ -391,12 +384,12 @@ class TritonPythonModel:
break
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(
this_tts_speech_token, request_id, prompt_speech_tokens,
prompt_speech_feat, prompt_spk_embedding, token_offset, False
chunk_index,
this_tts_speech_token, request_id, wav, wav_len, False
)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
@@ -429,8 +422,8 @@ class TritonPythonModel:
else:
time.sleep(0.02)
this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
@@ -439,17 +432,7 @@ class TritonPythonModel:
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info("send tritonserver_response_complete_final to end")
else:
generated_ids = next(generated_ids_iter)
generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
if generated_ids is None or len(generated_ids) == 0:
raise pb_utils.TritonModelException("Generated IDs is None or empty")
audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
# Prepare response
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
responses.append(inference_response)
raise NotImplementedError("Decoupled mode is not supported")
if not self.decoupled:
return responses

View File

@@ -0,0 +1,438 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import math
import os
import re
import time
from typing import Dict, List, Tuple, Optional, Union
import asyncio
import httpx
import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
import torchaudio
from matcha.utils.audio import mel_spectrogram
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
def parse_speech_token_string(response_text: str) -> List[int]:
"""
Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
"""
speech_tokens = response_text.strip().split('><')
if len(speech_tokens) > 1:
# Add back the missing '<' and '>' for proper parsing
speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
speech_ids = []
for token_str in speech_tokens:
match = re.match(r'<\|s_(\d+)\|>', token_str)
if match:
speech_ids.append(int(match.group(1)))
return speech_ids
class TritonPythonModel:
"""Triton Python model for Spark TTS.
This model orchestrates the end-to-end TTS pipeline by coordinating
between audio tokenizer, LLM, and vocoder components.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
self.logger = pb_utils.Logger
# Parse model parameters
self.model_config = json.loads(args['model_config'])
parameters = self.model_config['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.logger.log_info(f"model_params:{model_params}")
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
# Initialize tokenizer
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
self.prompt_template = "<|sos|>{input_text}<|task_id|>"
self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
self.device = torch.device("cuda")
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
self.token_frame_rate = 25
self.flow_pre_lookahead_len = 3
self.token_hop_len = 15
spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
if not os.path.exists(spk_info_path):
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
# self.default_spk_info = spk_info["001"]
def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
"""Converts a tensor or list of speech token IDs to a string representation."""
if isinstance(speech_tokens, torch.Tensor):
# Ensure tensor is on CPU and flattened
speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
speech_id_str = ""
for token_id in speech_tokens:
# Convert token ID back to the speech number N
token_num = token_id - ORIGINAL_VOCAB_SIZE
speech_id_str += f"<|s_{token_num}|>"
return speech_id_str
async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
"""
Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
"""
full_text = f"{reference_text}{target_text}"
prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
chat = [
{"role": "user", "content": full_text},
{"role": "assistant", "content": prompt_speech_tokens_str}
]
print(chat)
payload = {
"model": "trt_engines_bfloat16",
"messages": chat,
"max_tokens": 750,
"temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"repetition_penalty": 1.1,
"stop": ["<|eos1|>", "<|eos|>"],
"stream": True,
}
api_base = "http://localhost:8000/v1/chat/completions"
buffer = ""
async with httpx.AsyncClient() as client:
async with client.stream("POST", api_base, json=payload, timeout=None) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
line_data = line[len("data: "):].strip()
if line_data == "[DONE]":
break
try:
json_data = json.loads(line_data)
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
if content:
buffer += content
while True:
match = re.search(r"<\|s_(\d+)\|>", buffer)
if not match:
break
token_num = int(match.group(1))
final_id = token_num + ORIGINAL_VOCAB_SIZE
yield final_id
buffer = buffer[match.end():]
except json.JSONDecodeError:
self.logger.log_info(f"Skipping non-JSON line: {line_data}")
continue
# Process any remaining complete tokens in the buffer after the stream ends
while True:
match = re.search(r"<\|s_(\d+)\|>", buffer)
if not match:
break
token_num = int(match.group(1))
final_id = token_num + ORIGINAL_VOCAB_SIZE
yield final_id
buffer = buffer[match.end():]
def forward_audio_tokenizer(self, wav, wav_len):
"""Forward pass through the audio tokenizer component.
Args:
wav: Input waveform tensor
wav_len: Waveform length tensor
Returns:
Tuple of global and semantic tokens
"""
inference_request = pb_utils.InferenceRequest(
model_name='audio_tokenizer',
requested_output_names=['prompt_speech_tokens'],
inputs=[wav, wav_len]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
return prompt_speech_tokens
def forward_speaker_embedding(self, wav):
"""Forward pass through the speaker embedding component.
Args:
wav: Input waveform tensor
Returns:
Prompt speaker embedding tensor
"""
inference_request = pb_utils.InferenceRequest(
model_name='speaker_embedding',
requested_output_names=['prompt_spk_embedding'],
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
return prompt_spk_embedding
def forward_token2wav(
self,
index: int,
target_speech_tokens: torch.Tensor,
request_id: str,
reference_wav: object,
reference_wav_len: object,
finalize: bool = None) -> torch.Tensor:
"""Forward pass through the vocoder component.
Args:
prompt_speech_tokens: Prompt speech tokens tensor
prompt_speech_feat: Prompt speech feat tensor
prompt_spk_embedding: Prompt spk embedding tensor
target_speech_tokens: Target speech tokens tensor
Returns:
Generated waveform tensor
"""
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav_dit',
requested_output_names=['waveform'],
inputs=inputs_tensor,
request_id=request_id,
parameters={"priority": index+1},
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output waveform
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
return waveform
def _extract_speech_feat(self, speech):
speech_feat = mel_spectrogram(
speech,
n_fft=1920,
num_mels=80,
sampling_rate=24000,
hop_size=480,
win_size=1920,
fmin=0,
fmax=8000).squeeze(
dim=0).transpose(
0,
1).to(
self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat
async def _process_request(self, request):
request_id = request.request_id()
# Extract input tensors
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
# Process reference audio through audio tokenizer
if wav is not None:
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
wav_tensor = wav.as_numpy()
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
speech_feat = self._extract_speech_feat(prompt_speech_resample)
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
# prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
# reference_text = self.default_spk_info["prompt_text"]
# prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
# prompt_speech_feat = None
# prompt_spk_embedding = None
else:
# using pre-cached reference text
assert False, "using pre-cached reference text is not supported"
reference_text = self.default_spk_info["prompt_text"]
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
prompt_speech_feat = None
prompt_spk_embedding = None
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode('utf-8')
if self.decoupled:
response_sender = request.get_response_sender()
semantic_token_ids_arr = []
token_offset, chunk_index = 0, 0
start_time = time.time()
this_token_hop_len = self.token_hop_len
async for generated_ids in self.forward_llm_async(
target_text=target_text,
reference_text=reference_text,
prompt_speech_tokens=prompt_speech_tokens,
):
if not generated_ids:
break
semantic_token_ids_arr.append(generated_ids)
while True:
pending_num = len(semantic_token_ids_arr) - token_offset
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(
chunk_index,
this_tts_speech_token, request_id, wav, wav_len, False
)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
token_offset += this_token_hop_len
self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
if self.dynamic_chunk_strategy == "exponential":
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
elif self.dynamic_chunk_strategy == "time_based":
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
cost_time = time.time() - start_time
duration = token_offset / self.token_frame_rate
if chunk_index > 0 and cost_time > 0:
avg_chunk_processing_time = cost_time / (chunk_index + 1)
if avg_chunk_processing_time > 0:
multiples = (duration - cost_time) / avg_chunk_processing_time
self.logger.log_info(f"multiples: {multiples}")
next_pending_num = len(semantic_token_ids_arr) - token_offset
if multiples > 4:
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
elif multiples > 2:
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
else:
this_token_hop_len = self.token_hop_len
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
chunk_index += 1
else:
break
this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
## debug
## save semantic_token_ids_arr and reference_text, target_text to a single json file
# save into a torch .pt
# for i, item in enumerate(semantic_token_ids_arr):
# semantic_token_ids_arr[i] = item - ORIGINAL_VOCAB_SIZE
# import json
# data = {
# "semantic_token_ids_arr": semantic_token_ids_arr,
# "reference_text": reference_text,
# "target_text": target_text
# }
# with open(f"semantic_token_ids_arr_debug_{request_id}.pt", "wb") as f:
# torch.save(data, f)
# with open(f"semantic_token_ids_arr_debug_{request_id}.json", "w") as f:
# json.dump(data, f)
# ##
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info("send tritonserver_response_complete_final to end")
else:
raise NotImplementedError("Decoupled mode is not supported")
async def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated audio
"""
tasks = [
asyncio.create_task(self._process_request(request))
for request in requests
]
await asyncio.gather(*tasks)
return None

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
name: "cosyvoice2"
name: "cosyvoice2_dit"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {

View File

@@ -42,6 +42,8 @@ from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vl
from cosyvoice.utils.common import TrtContextWrapper
from collections import defaultdict
import numpy as np
from .token2wav_dit import CosyVoice2_Token2Wav
import hashlib
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
@@ -49,117 +51,19 @@ logger = logging.getLogger(__name__)
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
"""
Generates a unique ID for a torch.Tensor.
Tensors with the same elements and properties will have the same ID.
"""
# Convert tensor to a byte string
tensor_bytes = tensor.numpy().tobytes()
class CosyVoice2:
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
self.model_dir = model_dir
self.fp16 = fp16
hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
if not os.path.exists(hyper_yaml_path):
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
if load_jit:
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
if load_trt:
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
trt_concurrent,
self.fp16)
class CosyVoice2Model:
def __init__(self,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False,
device: str = 'cuda'):
self.device = device
self.flow = flow
self.hift = hift
self.fp16 = fp16
if self.fp16 is True:
self.flow.half()
# streaming tts config
self.token_hop_len = 25
self.mel_cache_len = 8
self.source_cache_len = int(self.mel_cache_len * 480)
self.speech_window = np.hamming(2 * self.source_cache_len)
self.hift_cache_dict = defaultdict(lambda: None)
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
def load(self, flow_model, hift_model):
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
self.flow.to(self.device).eval()
# in case hift_model is a hifigan model
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
self.hift.load_state_dict(hift_state_dict, strict=True)
self.hift.to(self.device).eval()
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
del self.flow.decoder.estimator
import tensorrt as trt
with open(flow_decoder_estimator_model, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_trt_kwargs(self):
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
input_names = ["x", "mask", "mu", "cond"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, _ = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
streaming=stream,
finalize=finalize)
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append hift cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
else:
hift_cache_source = torch.zeros(1, 1, 0)
# keep overlap mel and hift cache
if finalize is False:
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
'source': tts_source[:, :, -self.source_cache_len:],
'speech': tts_speech[:, -self.source_cache_len:]}
tts_speech = tts_speech[:, :-self.source_cache_len]
else:
if speed != 1.0:
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
return tts_speech
# Create a SHA-256 hash of the byte string
hasher = hashlib.sha256()
hasher.update(tensor_bytes)
return hasher.hexdigest()
class TritonPythonModel:
"""Triton Python model for vocoder.
@@ -183,16 +87,10 @@ class TritonPythonModel:
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
self.token2wav_model = CosyVoice2(
model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
# FIXME: device id settings
self.token2wav_model = CosyVoice2_Token2Wav(
model_dir, enable_trt=True, streaming=True
)
spk_info_path = os.path.join(model_dir, "spk2info.pt")
if not os.path.exists(spk_info_path):
raise ValueError(f"spk2info.pt not found in {model_dir}")
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
self.default_spk_info = spk_info["001"]
logger.info("Token2Wav initialized successfully")
def execute(self, requests):
@@ -208,66 +106,31 @@ class TritonPythonModel:
# Process each request in batch
for request in requests:
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
if prompt_speech_tokens_tensor is not None:
prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
else:
prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device)
# shift the speech tokens according to the original vocab size
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
target_speech_tokens = target_speech_tokens.squeeze().tolist()
# We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
if token_offset is not None:
token_offset = token_offset.as_numpy().item()
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
if not finalize:
stream = True
else:
stream = False
request_id = request.request_id()
audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
prompt_token=prompt_speech_tokens,
prompt_feat=prompt_speech_feat,
embedding=prompt_spk_embedding,
token_offset=token_offset,
uuid=request_id,
stream=stream,
finalize=finalize)
if finalize:
self.token2wav_model.model.hift_cache_dict.pop(request_id)
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
request_id = request.request_id()
else:
tts_mel, _ = self.token2wav_model.model.flow.inference(
token=target_speech_tokens,
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
self.device
),
prompt_token=prompt_speech_tokens,
prompt_token_len=torch.tensor(
[prompt_speech_tokens.shape[1]], dtype=torch.int32
).to(self.device),
prompt_feat=prompt_speech_feat,
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=prompt_spk_embedding,
streaming=False,
finalize=True,
)
wav_array = pb_utils.get_input_tensor_by_name(
request, "reference_wav").as_numpy()
wav_len = pb_utils.get_input_tensor_by_name(
request, "reference_wav_len").as_numpy().item()
audio_hat, _ = self.token2wav_model.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
wav_array = torch.from_numpy(wav_array)
# Prepare inputs
wav = wav_array[:, :wav_len].squeeze(0)
spk_id = get_spk_id_from_prompt_audio(wav)
# wav = wav.to(self.device)
audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
generated_wave = audio_hat.squeeze(0).cpu().numpy()

View File

@@ -0,0 +1,537 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
CUDA_VISIBLE_DEVICES=0 \
python3 token2wav.py --enable-trt || exit 1
"""
import torch
# from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
from flashcosyvoice.modules.hifigan import HiFTGenerator
from flashcosyvoice.utils.audio import mel_spectrogram
import torchaudio.compliance.kaldi as kaldi
import onnxruntime
import s3tokenizer
from torch.utils.data import DataLoader
from datasets import load_dataset
import torchaudio
import os
import logging
import argparse
import queue
import time
import numpy as np
from hyperpyyaml import load_hyperpyyaml
def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor):
"""perform fade_in_out in tensor style
"""
mel_overlap_len = int(window.shape[0] / 2)
fade_in_mel = fade_in_mel.clone()
fade_in_mel[..., :mel_overlap_len] = \
fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
return fade_in_mel
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
import tensorrt as trt
logging.info("Converting onnx to trt...")
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(network_flags)
parser = trt.OnnxParser(network, logger)
config = builder.create_builder_config()
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
if dtype == torch.float16:
config.set_flag(trt.BuilderFlag.FP16)
elif dtype == torch.bfloat16:
config.set_flag(trt.BuilderFlag.BF16)
elif dtype == torch.float32:
config.set_flag(trt.BuilderFlag.FP32)
profile = builder.create_optimization_profile()
# load onnx model
with open(onnx_model, "rb") as f:
if not parser.parse(f.read()):
for error in range(parser.num_errors):
print(parser.get_error(error))
raise ValueError('failed to parse {}'.format(onnx_model))
# set input shapes
for i in range(len(trt_kwargs['input_names'])):
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
if dtype == torch.float16:
tensor_dtype = trt.DataType.HALF
elif dtype == torch.bfloat16:
tensor_dtype = trt.DataType.BF16
elif dtype == torch.float32:
tensor_dtype = trt.DataType.FLOAT
else:
raise ValueError('invalid dtype {}'.format(dtype))
# set input and output data type
for i in range(network.num_inputs):
input_tensor = network.get_input(i)
input_tensor.dtype = tensor_dtype
for i in range(network.num_outputs):
output_tensor = network.get_output(i)
output_tensor.dtype = tensor_dtype
config.add_optimization_profile(profile)
engine_bytes = builder.build_serialized_network(network, config)
# save trt engine
with open(trt_model, "wb") as f:
f.write(engine_bytes)
logging.info("Succesfully convert onnx to trt...")
class TrtContextWrapper:
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
self.trt_engine = trt_engine
self.device = device
for _ in range(trt_concurrent):
trt_context = trt_engine.create_execution_context()
trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
self.trt_context_pool.put([trt_context, trt_stream])
assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
def acquire_estimator(self):
return self.trt_context_pool.get(), self.trt_engine
def release_estimator(self, context, stream):
self.trt_context_pool.put([context, stream])
class CosyVoice2_Token2Wav(torch.nn.Module):
def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
super().__init__()
self.device_id = device_id
self.device = f"cuda:{device_id}"
with open(f"{model_dir}/flow.yaml", "r") as f:
configs = load_hyperpyyaml(f)
self.flow = configs['flow']
self.dtype = dtype
self.flow.to(self.dtype)
self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
self.flow.to(self.device).eval()
self.hift = HiFTGenerator()
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
self.hift.load_state_dict(hift_state_dict, strict=True)
self.hift.to(self.device).eval()
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
providers=["CPUExecutionProvider"])
self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()
gpu="l20"
if enable_trt:
if streaming:
self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
1,
self.dtype, streaming)
else:
self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
1,
self.dtype)
self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
f'{model_dir}/campplus.onnx',
1,
False)
self.streaming_flow_cache = {}
self.speaker_cache = {}
self.mel_cache_len = 8 # hard-coded, 160ms
self.source_cache_len = int(self.mel_cache_len * 480) # 50hz mel -> 24kHz wave
self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda()
# hifigan cache for streaming tts
self.hift_cache_dict = {}
def forward_spk_embedding(self, spk_feat):
if isinstance(self.spk_model, onnxruntime.InferenceSession):
return self.spk_model.run(
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
)[0].flatten().tolist()
else:
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
# NOTE need to synchronize when switching stream
with torch.cuda.device(self.device_id):
torch.cuda.current_stream().synchronize()
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
batch_size = spk_feat.size(0)
with stream:
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
data_ptrs = [spk_feat.contiguous().data_ptr(),
output_tensor.contiguous().data_ptr()]
for i, j in enumerate(data_ptrs):
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.spk_model.release_estimator(spk_model, stream)
return output_tensor.cpu().numpy().flatten().tolist()
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
trt_kwargs = self.get_spk_trt_kwargs()
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
import tensorrt as trt
with open(spk_model, 'rb') as f:
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_spk_trt_kwargs(self):
min_shape = [(1, 4, 80)]
opt_shape = [(1, 500, 80)]
max_shape = [(1, 3000, 80)]
input_names = ["input"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
opt_batch_size = 2
max_batch_size = 16
if streaming:
opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts
trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
del self.flow.decoder.estimator
import tensorrt as trt
with open(flow_decoder_estimator_model, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
if streaming:
min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80), (16, opt_batch_size*2, 1024, 2), (16, opt_batch_size*2, 8, 100, 128)]
max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80), (16, max_batch_size*2, 1024, 2), (16, max_batch_size*2, 8, 1000, 128)]
input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
else:
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)]
max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)]
input_names = ["x", "mask", "mu", "cond", "t", "spks"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
prompt_speech_tokens_list, prompt_speech_mels_list = [], []
for audio in prompt_audios_list:
assert len(audio.shape) == 1
log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T]
prompt_speech_mels_list.append(log_mel)
prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
)
for i in range(len(prompt_speech_tokens)):
speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
prompt_speech_tokens_list.append(speech_tokens_i)
return prompt_speech_tokens_list
def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
spk_emb_for_flow = []
for audio in prompt_audios_list:
assert len(audio.shape) == 1
spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
spk_emb = self.forward_spk_embedding(spk_feat)
spk_emb_for_flow.append(spk_emb)
spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
if self.dtype != torch.float32:
spk_emb_for_flow = spk_emb_for_flow.to(self.dtype)
return spk_emb_for_flow
def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
prompt_mels_for_flow = []
prompt_mels_lens_for_flow = []
for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
assert len(audio.shape) == 1
audio = audio.unsqueeze(0)
if sample_rate != 24000:
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels]
mel_len = mel.shape[0]
prompt_mels_for_flow.append(mel)
prompt_mels_lens_for_flow.append(mel_len)
prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80]
prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
return prompt_mels_for_flow, prompt_mels_lens_for_flow
def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
batch_size = prompt_mels_for_flow.shape[0]
flow_inputs = []
flow_inputs_lens = []
for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
flow_inputs_lens = torch.tensor(flow_inputs_lens)
with torch.amp.autocast(self.device, dtype=torch.float16):
generated_mels, generated_mels_lens = self.flow.inference(
flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10
)
return generated_mels, generated_mels_lens
def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
batch_size = generated_mels.shape[0]
generated_wavs = []
for i in range(batch_size):
mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
wav, _ = self.hift(speech_feat=mel)
generated_wavs.append(wav)
return generated_wavs
@torch.inference_mode()
def forward(
self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
):
# assert all item in prompt_audios_sample_rate is 16000
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
return generated_wavs
def prepare_prompt_audio(
self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
):
# assert all item in prompt_audios_sample_rate is 16000
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
def get_prompt_audio_cache_for_streaming_tts(
self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
):
assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts"
for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list):
prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3])
prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0)
cache = self.flow.setup_cache(
prompt_speech_tokens_tensor.to(self.device),
prompt_mels_for_flow.to(self.device),
spk_emb_for_flow.to(self.device),
n_timesteps=10
)
new_cache = {k: v.clone() for k, v in cache.items()}
# Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
return new_cache
@torch.inference_mode()
def forward_streaming(
self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
):
if speaker_id not in self.speaker_cache:
# if 1:
assert prompt_audio is not None, "prompt_audio is required for new speaker"
assert prompt_audio_sample_rate == 16000
prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate])
token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0]))
prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous()
prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len]
prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow}
# if speaker_id not in self.speaker_cache:
# if 1:
cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
print(f"speaker_id {speaker_id} added to cache")
# get a clone of cache dict ['estimator_att_cache'] and later check if it would be change
att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache'].clone()
cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache'].clone()
conformer_cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache'].clone()
conformer_att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache'].clone()
if request_id not in self.streaming_flow_cache:
self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
self.hift_cache_dict[request_id] = dict(
mel = torch.zeros(1, 80, 0, device='cuda'),
source = torch.zeros(1, 1, 0, device='cuda'),
speech = torch.zeros(1, 0, device='cuda'),
)
current_request_cache = self.streaming_flow_cache[request_id]
current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
token=generated_speech_tokens,
spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
cache=current_request_cache,
last_chunk=last_chunk,
n_timesteps=10,
)
# get the original att_cache
original_att_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache']
original_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache']
original_conformer_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache']
original_conformer_att_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache']
if not torch.allclose(original_att_cache, att_cache_clone):
print("att_cache changed")
# print the last 10 elements of original_att_cache and att_cache_clone
print(original_att_cache[:, :, :, -10:])
print(att_cache_clone[:, :, :, -10:])
breakpoint()
if not torch.allclose(original_cnn_cache, cnn_cache_clone):
print("cnn_cache changed")
print(original_cnn_cache[..., -10:])
print(cnn_cache_clone[..., -10:])
breakpoint()
if not torch.allclose(original_conformer_cnn_cache, conformer_cnn_cache_clone):
print("conformer_cnn_cache changed")
print(original_conformer_cnn_cache[..., -10:])
print(conformer_cnn_cache_clone[..., -10:])
breakpoint()
if not torch.allclose(original_conformer_att_cache, conformer_att_cache_clone):
print("conformer_att_cache changed")
print(original_conformer_att_cache[..., -10:])
print(conformer_att_cache_clone[..., -10:])
breakpoint()
self.streaming_flow_cache[request_id] = new_streaming_flow_cache
if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
], dim=4)
hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()
mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone()
speech, source = self.hift(mel, hift_cache_source)
# overlap speech smooth
if hift_cache_speech.shape[-1] > 0:
speech = fade_in_out(speech, hift_cache_speech, self.speech_window)
# update vocoder cache
self.hift_cache_dict[request_id] = dict(
mel = mel[..., -self.mel_cache_len:].clone().detach(),
source = source[:, :, -self.source_cache_len:].clone().detach(),
speech = speech[:, -self.source_cache_len:].clone().detach(),
)
if not last_chunk:
speech = speech[:, :-self.source_cache_len]
if last_chunk:
assert request_id in self.streaming_flow_cache
self.streaming_flow_cache.pop(request_id)
self.hift_cache_dict.pop(request_id)
# breakpoint()
return speech
def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
for i, item in enumerate(batch):
generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
audio = torch.from_numpy(item['prompt_audio']['array']).float()
prompt_audios_list.append(audio)
prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
ids.append(item['id'])
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--enable-trt", action="store_true")
parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument("--output-dir", type=str, default="generated_wavs")
parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
# mkdir output_dir if not exists
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
dataset_name = "yuekai/seed_tts_cosy2"
dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
for epoch in range(args.warmup):
start_time = time.time()
for batch in data_loader:
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
for id, wav in zip(ids, generated_wavs):
torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
end_time = time.time()
epoch_time = end_time - start_time
print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")

View File

@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
name: "token2wav"
name: "token2wav_dit"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
priority_levels: 10
default_priority_level: 10
}
parameters [
{
@@ -32,29 +34,14 @@ input [
dims: [-1]
},
{
name: "prompt_speech_tokens"
data_type: TYPE_INT32
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
optional: true
},
{
name: "prompt_speech_feat"
data_type: TYPE_FP16
dims: [-1, 80]
optional: true
},
{
name: "prompt_spk_embedding"
data_type: TYPE_FP16
dims: [-1]
optional: true
},
{
name: "token_offset"
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
dims: [1]
},
{
name: "finalize"