Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 18:09:24 +08:00)

Commit: add streaming dit
@@ -227,12 +227,11 @@ class TritonPythonModel:

    def forward_token2wav(
            self,
            index: int,
            target_speech_tokens: torch.Tensor,
            request_id: str,
            prompt_speech_tokens: torch.Tensor = None,
            prompt_speech_feat: torch.Tensor = None,
            prompt_spk_embedding: torch.Tensor = None,
            token_offset: int = None,
            reference_wav: object,
            reference_wav_len: object,
            finalize: bool = None) -> torch.Tensor:
        """Forward pass through the vocoder component.

@@ -246,29 +245,16 @@ class TritonPythonModel:
            Generated waveform tensor
        """
        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))

        inputs_tensor = [target_speech_tokens_tensor]

        if token_offset is not None:
            assert finalize is not None
            token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
            finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
            inputs_tensor.append(token_offset_tensor)
            inputs_tensor.append(finalize_tensor)

        if prompt_spk_embedding is not None:
            assert prompt_speech_feat is not None
            prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
            prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
            prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
            inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
        finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
        inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]

        # Create and execute inference request
        inference_request = pb_utils.InferenceRequest(
            model_name='token2wav',
            model_name='token2wav_dit',
            requested_output_names=['waveform'],
            inputs=inputs_tensor,
            request_id=request_id,
            parameters={"priority": index+1},
        )

        inference_response = inference_request.exec()
@@ -346,8 +332,15 @@ class TritonPythonModel:

            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
            reference_text = reference_text[0][0].decode('utf-8')
            prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
            # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)

            # reference_text = self.default_spk_info["prompt_text"]
            # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
            # prompt_speech_feat = None
            # prompt_spk_embedding = None

        else:
            assert False, "wav is None"
            # using pre-cached reference text
            reference_text = self.default_spk_info["prompt_text"]
            prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
@@ -391,12 +384,12 @@ class TritonPythonModel:
                        break

                    if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
                        this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                        this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)

                        sub_tts_speech = self.forward_token2wav(
                            this_tts_speech_token, request_id, prompt_speech_tokens,
                            prompt_speech_feat, prompt_spk_embedding, token_offset, False
                            chunk_index,
                            this_tts_speech_token, request_id, wav, wav_len, False
                        )

                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
@@ -429,8 +422,8 @@ class TritonPythonModel:
                    else:
                        time.sleep(0.02)

            this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
            sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
            this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
            sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
            audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
            inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
            response_sender.send(inference_response)
@@ -439,17 +432,7 @@ class TritonPythonModel:
            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
            self.logger.log_info("send tritonserver_response_complete_final to end")
        else:
            generated_ids = next(generated_ids_iter)
            generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
            if generated_ids is None or len(generated_ids) == 0:
                raise pb_utils.TritonModelException("Generated IDs is None or empty")

            audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)

            # Prepare response
            audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
            inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
            responses.append(inference_response)
            raise NotImplementedError("Decoupled mode is not supported")

        if not self.decoupled:
            return responses
438  runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py  (new file)
@@ -0,0 +1,438 @@
|
||||
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import Dict, List, Tuple, Optional, Union
|
||||
import asyncio
|
||||
import httpx
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
import triton_python_backend_utils as pb_utils
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import torchaudio
|
||||
|
||||
|
||||
from matcha.utils.audio import mel_spectrogram
|
||||
|
||||
ORIGINAL_VOCAB_SIZE = 151663
|
||||
torch.set_num_threads(1)
|
||||
|
||||
|
||||
def parse_speech_token_string(response_text: str) -> List[int]:
|
||||
"""
|
||||
Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
|
||||
"""
|
||||
speech_tokens = response_text.strip().split('><')
|
||||
if len(speech_tokens) > 1:
|
||||
# Add back the missing '<' and '>' for proper parsing
|
||||
speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
|
||||
speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
|
||||
|
||||
speech_ids = []
|
||||
for token_str in speech_tokens:
|
||||
match = re.match(r'<\|s_(\d+)\|>', token_str)
|
||||
if match:
|
||||
speech_ids.append(int(match.group(1)))
|
||||
return speech_ids
|
||||
|
||||
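# Usage sketch (hypothetical token string, not used by the serving path): the LLM
# streams speech tokens in the "<|s_N|>" format, and the helper above recovers the
# raw integer IDs before the ORIGINAL_VOCAB_SIZE offset is added back.
def _parse_speech_token_string_example() -> List[int]:
    return parse_speech_token_string("<|s_123|><|s_456|>")  # -> [123, 456]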
|
||||
class TritonPythonModel:
|
||||
"""Triton Python model for Spark TTS.
|
||||
|
||||
This model orchestrates the end-to-end TTS pipeline by coordinating
|
||||
between audio tokenizer, LLM, and vocoder components.
|
||||
"""
|
||||
|
||||
def initialize(self, args):
|
||||
"""Initialize the model.
|
||||
|
||||
Args:
|
||||
args: Dictionary containing model configuration
|
||||
"""
|
||||
self.logger = pb_utils.Logger
|
||||
# Parse model parameters
|
||||
self.model_config = json.loads(args['model_config'])
|
||||
parameters = self.model_config['parameters']
|
||||
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
||||
self.logger.log_info(f"model_params:{model_params}")
|
||||
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
|
||||
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
|
||||
|
||||
# Initialize tokenizer
|
||||
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
|
||||
self.prompt_template = "<|sos|>{input_text}<|task_id|>"
|
||||
self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
|
||||
|
||||
self.device = torch.device("cuda")
|
||||
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
|
||||
|
||||
self.token_frame_rate = 25
|
||||
self.flow_pre_lookahead_len = 3
|
||||
self.token_hop_len = 15
|
||||
|
||||
spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
|
||||
if not os.path.exists(spk_info_path):
|
||||
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
|
||||
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
|
||||
# self.default_spk_info = spk_info["001"]
|
||||
|
||||
def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
|
||||
"""Converts a tensor or list of speech token IDs to a string representation."""
|
||||
if isinstance(speech_tokens, torch.Tensor):
|
||||
# Ensure tensor is on CPU and flattened
|
||||
speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
|
||||
|
||||
speech_id_str = ""
|
||||
for token_id in speech_tokens:
|
||||
# Convert token ID back to the speech number N
|
||||
token_num = token_id - ORIGINAL_VOCAB_SIZE
|
||||
speech_id_str += f"<|s_{token_num}|>"
|
||||
return speech_id_str
|
||||
|
||||
async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
|
||||
"""
|
||||
Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
|
||||
"""
|
||||
full_text = f"{reference_text}{target_text}"
|
||||
prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
|
||||
|
||||
chat = [
|
||||
{"role": "user", "content": full_text},
|
||||
{"role": "assistant", "content": prompt_speech_tokens_str}
|
||||
]
|
||||
print(chat)
|
||||
|
||||
payload = {
|
||||
"model": "trt_engines_bfloat16",
|
||||
"messages": chat,
|
||||
"max_tokens": 750,
|
||||
"temperature": 0.8,
|
||||
"top_p": 0.95,
|
||||
"top_k": 50,
|
||||
"repetition_penalty": 1.1,
|
||||
"stop": ["<|eos1|>", "<|eos|>"],
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
api_base = "http://localhost:8000/v1/chat/completions"
|
||||
|
||||
buffer = ""
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream("POST", api_base, json=payload, timeout=None) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
line_data = line[len("data: "):].strip()
|
||||
if line_data == "[DONE]":
|
||||
break
|
||||
try:
|
||||
json_data = json.loads(line_data)
|
||||
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
|
||||
if content:
|
||||
buffer += content
|
||||
while True:
|
||||
match = re.search(r"<\|s_(\d+)\|>", buffer)
|
||||
if not match:
|
||||
break
|
||||
|
||||
token_num = int(match.group(1))
|
||||
final_id = token_num + ORIGINAL_VOCAB_SIZE
|
||||
yield final_id
|
||||
buffer = buffer[match.end():]
|
||||
except json.JSONDecodeError:
|
||||
self.logger.log_info(f"Skipping non-JSON line: {line_data}")
|
||||
continue
|
||||
|
||||
# Process any remaining complete tokens in the buffer after the stream ends
|
||||
while True:
|
||||
match = re.search(r"<\|s_(\d+)\|>", buffer)
|
||||
if not match:
|
||||
break
|
||||
token_num = int(match.group(1))
|
||||
final_id = token_num + ORIGINAL_VOCAB_SIZE
|
||||
yield final_id
|
||||
buffer = buffer[match.end():]
|
||||
|
||||
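# Illustrative sketch of the incremental buffering above (not wired into the stream
# loop): complete "<|s_N|>" tokens are extracted in order and shifted by
# ORIGINAL_VOCAB_SIZE, while a trailing partial token such as "<|s_12" stays in the
# buffer until the next delta arrives.
@staticmethod
def _drain_speech_tokens(buffer: str) -> Tuple[List[int], str]:
    ids = []
    while True:
        match = re.search(r"<\|s_(\d+)\|>", buffer)
        if not match:
            break
        ids.append(int(match.group(1)) + ORIGINAL_VOCAB_SIZE)
        buffer = buffer[match.end():]
    return ids, buffer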
|
||||
def forward_audio_tokenizer(self, wav, wav_len):
|
||||
"""Forward pass through the audio tokenizer component.
|
||||
|
||||
Args:
|
||||
wav: Input waveform tensor
|
||||
wav_len: Waveform length tensor
|
||||
|
||||
Returns:
|
||||
Tuple of global and semantic tokens
|
||||
"""
|
||||
inference_request = pb_utils.InferenceRequest(
|
||||
model_name='audio_tokenizer',
|
||||
requested_output_names=['prompt_speech_tokens'],
|
||||
inputs=[wav, wav_len]
|
||||
)
|
||||
|
||||
inference_response = inference_request.exec()
|
||||
if inference_response.has_error():
|
||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||
|
||||
# Extract and convert output tensors
|
||||
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
|
||||
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
|
||||
|
||||
return prompt_speech_tokens
|
||||
|
||||
def forward_speaker_embedding(self, wav):
|
||||
"""Forward pass through the speaker embedding component.
|
||||
|
||||
Args:
|
||||
wav: Input waveform tensor
|
||||
|
||||
Returns:
|
||||
Prompt speaker embedding tensor
|
||||
"""
|
||||
inference_request = pb_utils.InferenceRequest(
|
||||
model_name='speaker_embedding',
|
||||
requested_output_names=['prompt_spk_embedding'],
|
||||
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
|
||||
)
|
||||
|
||||
inference_response = inference_request.exec()
|
||||
if inference_response.has_error():
|
||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||
|
||||
# Extract and convert output tensors
|
||||
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
|
||||
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
|
||||
|
||||
return prompt_spk_embedding
|
||||
|
||||
def forward_token2wav(
|
||||
self,
|
||||
index: int,
|
||||
target_speech_tokens: torch.Tensor,
|
||||
request_id: str,
|
||||
reference_wav: object,
|
||||
reference_wav_len: object,
|
||||
finalize: bool = None) -> torch.Tensor:
|
||||
"""Forward pass through the vocoder component.
|
||||
|
||||
Args:
|
||||
prompt_speech_tokens: Prompt speech tokens tensor
|
||||
prompt_speech_feat: Prompt speech feat tensor
|
||||
prompt_spk_embedding: Prompt spk embedding tensor
|
||||
target_speech_tokens: Target speech tokens tensor
|
||||
|
||||
Returns:
|
||||
Generated waveform tensor
|
||||
"""
|
||||
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
|
||||
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
|
||||
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
|
||||
|
||||
# Create and execute inference request
|
||||
inference_request = pb_utils.InferenceRequest(
|
||||
model_name='token2wav_dit',
|
||||
requested_output_names=['waveform'],
|
||||
inputs=inputs_tensor,
|
||||
request_id=request_id,
|
||||
parameters={"priority": index+1},
|
||||
)
|
||||
|
||||
inference_response = inference_request.exec()
|
||||
if inference_response.has_error():
|
||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||
|
||||
# Extract and convert output waveform
|
||||
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
|
||||
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
|
||||
|
||||
return waveform
|
||||
|
||||
def _extract_speech_feat(self, speech):
|
||||
speech_feat = mel_spectrogram(
|
||||
speech,
|
||||
n_fft=1920,
|
||||
num_mels=80,
|
||||
sampling_rate=24000,
|
||||
hop_size=480,
|
||||
win_size=1920,
|
||||
fmin=0,
|
||||
fmax=8000).squeeze(
|
||||
dim=0).transpose(
|
||||
0,
|
||||
1).to(
|
||||
self.device)
|
||||
speech_feat = speech_feat.unsqueeze(dim=0)
|
||||
return speech_feat
|
||||
|
||||
async def _process_request(self, request):
|
||||
request_id = request.request_id()
|
||||
# Extract input tensors
|
||||
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
||||
|
||||
# Process reference audio through audio tokenizer
|
||||
if wav is not None:
|
||||
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
||||
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
||||
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
|
||||
|
||||
wav_tensor = wav.as_numpy()
|
||||
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
|
||||
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
|
||||
speech_feat = self._extract_speech_feat(prompt_speech_resample)
|
||||
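# The 24 kHz mel features below use a 480-sample hop (50 frames/s) while speech tokens
# run at 25 Hz, so each prompt token lines up with 2 mel frames; trim both to the
# shorter of the two before handing them to the flow model.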
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
|
||||
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
|
||||
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
|
||||
|
||||
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
||||
reference_text = reference_text[0][0].decode('utf-8')
|
||||
# prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
|
||||
|
||||
# reference_text = self.default_spk_info["prompt_text"]
|
||||
# prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
|
||||
# prompt_speech_feat = None
|
||||
# prompt_spk_embedding = None
|
||||
|
||||
else:
|
||||
# using pre-cached reference text
|
||||
assert False, "using pre-cached reference text is not supported"
|
||||
reference_text = self.default_spk_info["prompt_text"]
|
||||
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
|
||||
prompt_speech_feat = None
|
||||
prompt_spk_embedding = None
|
||||
|
||||
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
||||
target_text = target_text[0][0].decode('utf-8')
|
||||
|
||||
if self.decoupled:
|
||||
response_sender = request.get_response_sender()
|
||||
|
||||
semantic_token_ids_arr = []
|
||||
token_offset, chunk_index = 0, 0
|
||||
start_time = time.time()
|
||||
this_token_hop_len = self.token_hop_len
|
||||
|
||||
async for generated_ids in self.forward_llm_async(
|
||||
target_text=target_text,
|
||||
reference_text=reference_text,
|
||||
prompt_speech_tokens=prompt_speech_tokens,
|
||||
):
|
||||
if not generated_ids:
|
||||
break
|
||||
semantic_token_ids_arr.append(generated_ids)
|
||||
|
||||
while True:
|
||||
pending_num = len(semantic_token_ids_arr) - token_offset
|
||||
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
|
||||
this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
|
||||
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
||||
|
||||
sub_tts_speech = self.forward_token2wav(
|
||||
chunk_index,
|
||||
this_tts_speech_token, request_id, wav, wav_len, False
|
||||
)
|
||||
|
||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||
response_sender.send(inference_response)
|
||||
|
||||
token_offset += this_token_hop_len
|
||||
self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
|
||||
|
||||
if self.dynamic_chunk_strategy == "exponential":
|
||||
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
|
||||
elif self.dynamic_chunk_strategy == "time_based":
|
||||
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
|
||||
cost_time = time.time() - start_time
|
||||
duration = token_offset / self.token_frame_rate
|
||||
if chunk_index > 0 and cost_time > 0:
|
||||
avg_chunk_processing_time = cost_time / (chunk_index + 1)
|
||||
if avg_chunk_processing_time > 0:
|
||||
multiples = (duration - cost_time) / avg_chunk_processing_time
|
||||
self.logger.log_info(f"multiples: {multiples}")
|
||||
next_pending_num = len(semantic_token_ids_arr) - token_offset
|
||||
if multiples > 4:
|
||||
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
|
||||
elif multiples > 2:
|
||||
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
|
||||
else:
|
||||
this_token_hop_len = self.token_hop_len
|
||||
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
|
||||
chunk_index += 1
|
||||
else:
|
||||
break
|
||||
|
||||
this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
||||
sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
|
||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||
response_sender.send(inference_response)
|
||||
|
||||
## debug
|
||||
## save semantic_token_ids_arr and reference_text, target_text to a single json file
|
||||
# save into a torch .pt
|
||||
# for i, item in enumerate(semantic_token_ids_arr):
|
||||
# semantic_token_ids_arr[i] = item - ORIGINAL_VOCAB_SIZE
|
||||
# import json
|
||||
# data = {
|
||||
# "semantic_token_ids_arr": semantic_token_ids_arr,
|
||||
# "reference_text": reference_text,
|
||||
# "target_text": target_text
|
||||
# }
|
||||
# with open(f"semantic_token_ids_arr_debug_{request_id}.pt", "wb") as f:
|
||||
# torch.save(data, f)
|
||||
# with open(f"semantic_token_ids_arr_debug_{request_id}.json", "w") as f:
|
||||
# json.dump(data, f)
|
||||
|
||||
# ##
|
||||
|
||||
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
|
||||
self.logger.log_info("send tritonserver_response_complete_final to end")
|
||||
else:
|
||||
raise NotImplementedError("Decoupled mode is not supported")
|
||||
|
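# Sketch of the exponential chunk schedule used in _process_request above (assumes
# token_hop_len=15 and token_frame_rate=25, as set in initialize): successive chunk
# sizes grow as 15, 25, 50, 100, ... tokens, so the first audio chunk is emitted
# quickly and later chunks amortize per-call overhead.
@staticmethod
def _exponential_hop_lengths(num_chunks: int, token_hop_len: int = 15, token_frame_rate: int = 25) -> List[int]:
    hops = []
    this_hop = token_hop_len
    for chunk_index in range(num_chunks):
        hops.append(this_hop)
        this_hop = token_frame_rate * (2 ** chunk_index)
    return hops  # e.g. num_chunks=5 -> [15, 25, 50, 100, 200]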
||||
async def execute(self, requests):
|
||||
"""Execute inference on the batched requests.
|
||||
|
||||
Args:
|
||||
requests: List of inference requests
|
||||
|
||||
Returns:
|
||||
List of inference responses containing generated audio
|
||||
"""
|
||||
tasks = [
|
||||
asyncio.create_task(self._process_request(request))
|
||||
for request in requests
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
return None
|
||||
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

name: "cosyvoice2"
name: "cosyvoice2_dit"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
@@ -42,6 +42,8 @@ from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vl
|
||||
from cosyvoice.utils.common import TrtContextWrapper
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
from .token2wav_dit import CosyVoice2_Token2Wav
|
||||
import hashlib
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -49,117 +51,19 @@ logger = logging.getLogger(__name__)
|
||||
ORIGINAL_VOCAB_SIZE = 151663
|
||||
torch.set_num_threads(1)
|
||||
|
||||
def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
|
||||
"""
|
||||
Generates a unique ID for a torch.Tensor.
|
||||
Tensors with the same elements and properties will have the same ID.
|
||||
"""
|
||||
# Convert tensor to a byte string
|
||||
tensor_bytes = tensor.numpy().tobytes()
|
||||
|
||||
class CosyVoice2:
|
||||
|
||||
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
|
||||
|
||||
self.model_dir = model_dir
|
||||
self.fp16 = fp16
|
||||
|
||||
hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
|
||||
if not os.path.exists(hyper_yaml_path):
|
||||
raise ValueError('{} not found!'.format(hyper_yaml_path))
|
||||
with open(hyper_yaml_path, 'r') as f:
|
||||
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
||||
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
|
||||
self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
|
||||
if load_jit:
|
||||
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||
if load_trt:
|
||||
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||
trt_concurrent,
|
||||
self.fp16)
|
||||
|
||||
|
||||
class CosyVoice2Model:
|
||||
|
||||
def __init__(self,
|
||||
flow: torch.nn.Module,
|
||||
hift: torch.nn.Module,
|
||||
fp16: bool = False,
|
||||
device: str = 'cuda'):
|
||||
self.device = device
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
if self.fp16 is True:
|
||||
self.flow.half()
|
||||
|
||||
# streaming tts config
|
||||
self.token_hop_len = 25
|
||||
self.mel_cache_len = 8
|
||||
self.source_cache_len = int(self.mel_cache_len * 480)
|
||||
self.speech_window = np.hamming(2 * self.source_cache_len)
|
||||
self.hift_cache_dict = defaultdict(lambda: None)
|
||||
|
||||
def load_jit(self, flow_encoder_model):
|
||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||
self.flow.encoder = flow_encoder
|
||||
|
||||
def load(self, flow_model, hift_model):
|
||||
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
||||
self.flow.to(self.device).eval()
|
||||
# in case hift_model is a hifigan model
|
||||
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
|
||||
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||
self.hift.to(self.device).eval()
|
||||
|
||||
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
|
||||
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
||||
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
|
||||
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
|
||||
del self.flow.decoder.estimator
|
||||
import tensorrt as trt
|
||||
with open(flow_decoder_estimator_model, 'rb') as f:
|
||||
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
|
||||
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||
|
||||
def get_trt_kwargs(self):
|
||||
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
|
||||
opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
|
||||
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
|
||||
input_names = ["x", "mask", "mu", "cond"]
|
||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
|
||||
with torch.cuda.amp.autocast(self.fp16):
|
||||
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_feat=prompt_feat.to(self.device),
|
||||
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=embedding.to(self.device),
|
||||
streaming=stream,
|
||||
finalize=finalize)
|
||||
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
||||
# append hift cache
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
||||
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
||||
else:
|
||||
hift_cache_source = torch.zeros(1, 1, 0)
|
||||
# keep overlap mel and hift cache
|
||||
if finalize is False:
|
||||
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
||||
'source': tts_source[:, :, -self.source_cache_len:],
|
||||
'speech': tts_speech[:, -self.source_cache_len:]}
|
||||
tts_speech = tts_speech[:, :-self.source_cache_len]
|
||||
else:
|
||||
if speed != 1.0:
|
||||
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
||||
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
||||
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||
return tts_speech
|
||||
|
||||
# Create a SHA-256 hash of the byte string
|
||||
hasher = hashlib.sha256()
|
||||
hasher.update(tensor_bytes)
|
||||
|
||||
return hasher.hexdigest()
|
||||
|
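# Usage sketch (hypothetical tensors): identical prompt audio always hashes to the same
# ID, so the streaming token2wav path can key its per-speaker cache by audio content
# rather than by request.
def _spk_id_example() -> bool:
    a = get_spk_id_from_prompt_audio(torch.zeros(160))
    b = get_spk_id_from_prompt_audio(torch.zeros(160))
    return a == b  # True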
||||
class TritonPythonModel:
|
||||
"""Triton Python model for vocoder.
|
||||
@@ -183,16 +87,10 @@ class TritonPythonModel:
|
||||
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
|
||||
|
||||
self.token2wav_model = CosyVoice2(
|
||||
model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
|
||||
# FIXME: device id settings
|
||||
self.token2wav_model = CosyVoice2_Token2Wav(
|
||||
model_dir, enable_trt=True, streaming=True
|
||||
)
|
||||
|
||||
spk_info_path = os.path.join(model_dir, "spk2info.pt")
|
||||
if not os.path.exists(spk_info_path):
|
||||
raise ValueError(f"spk2info.pt not found in {model_dir}")
|
||||
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
|
||||
self.default_spk_info = spk_info["001"]
|
||||
|
||||
logger.info("Token2Wav initialized successfully")
|
||||
|
||||
def execute(self, requests):
|
||||
@@ -208,66 +106,31 @@ class TritonPythonModel:
|
||||
# Process each request in batch
|
||||
for request in requests:
|
||||
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
|
||||
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
|
||||
|
||||
prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
|
||||
if prompt_speech_tokens_tensor is not None:
|
||||
prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
|
||||
prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
|
||||
prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
|
||||
prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
|
||||
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
|
||||
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
|
||||
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
|
||||
else:
|
||||
prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
|
||||
prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
|
||||
prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)
|
||||
|
||||
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device)
|
||||
# shift the speech tokens according to the original vocab size
|
||||
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
|
||||
target_speech_tokens = target_speech_tokens.squeeze().tolist()
|
||||
|
||||
# We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
|
||||
token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
|
||||
if token_offset is not None:
|
||||
token_offset = token_offset.as_numpy().item()
|
||||
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
|
||||
if not finalize:
|
||||
stream = True
|
||||
else:
|
||||
stream = False
|
||||
request_id = request.request_id()
|
||||
audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
|
||||
prompt_token=prompt_speech_tokens,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=prompt_spk_embedding,
|
||||
token_offset=token_offset,
|
||||
uuid=request_id,
|
||||
stream=stream,
|
||||
finalize=finalize)
|
||||
if finalize:
|
||||
self.token2wav_model.model.hift_cache_dict.pop(request_id)
|
||||
|
||||
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
|
||||
|
||||
request_id = request.request_id()
|
||||
|
||||
|
||||
else:
|
||||
tts_mel, _ = self.token2wav_model.model.flow.inference(
|
||||
token=target_speech_tokens,
|
||||
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
|
||||
self.device
|
||||
),
|
||||
prompt_token=prompt_speech_tokens,
|
||||
prompt_token_len=torch.tensor(
|
||||
[prompt_speech_tokens.shape[1]], dtype=torch.int32
|
||||
).to(self.device),
|
||||
prompt_feat=prompt_speech_feat,
|
||||
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=prompt_spk_embedding,
|
||||
streaming=False,
|
||||
finalize=True,
|
||||
)
|
||||
wav_array = pb_utils.get_input_tensor_by_name(
|
||||
request, "reference_wav").as_numpy()
|
||||
wav_len = pb_utils.get_input_tensor_by_name(
|
||||
request, "reference_wav_len").as_numpy().item()
|
||||
|
||||
audio_hat, _ = self.token2wav_model.model.hift.inference(
|
||||
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||
)
|
||||
wav_array = torch.from_numpy(wav_array)
|
||||
# Prepare inputs
|
||||
wav = wav_array[:, :wav_len].squeeze(0)
|
||||
|
||||
spk_id = get_spk_id_from_prompt_audio(wav)
|
||||
# wav = wav.to(self.device)
|
||||
|
||||
audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
|
||||
|
||||
generated_wave = audio_hat.squeeze(0).cpu().numpy()
|
||||
|
||||
|
||||
@@ -0,0 +1,537 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Example Usage
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python3 token2wav.py --enable-trt || exit 1
|
||||
"""
|
||||
import torch
|
||||
# from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
|
||||
from flashcosyvoice.modules.hifigan import HiFTGenerator
|
||||
from flashcosyvoice.utils.audio import mel_spectrogram
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
import onnxruntime
|
||||
import s3tokenizer
|
||||
from torch.utils.data import DataLoader
|
||||
from datasets import load_dataset
|
||||
import torchaudio
|
||||
import os
|
||||
import logging
|
||||
import argparse
|
||||
import queue
|
||||
import time
|
||||
import numpy as np
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
|
||||
|
||||
def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor):
|
||||
"""perform fade_in_out in tensor style
|
||||
"""
|
||||
mel_overlap_len = int(window.shape[0] / 2)
|
||||
fade_in_mel = fade_in_mel.clone()
|
||||
fade_in_mel[..., :mel_overlap_len] = \
|
||||
fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
|
||||
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
|
||||
return fade_in_mel
|
||||
|
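# Minimal sketch (synthetic shapes): cross-fade the head of a new audio chunk with the
# tail of the previous one using a Hamming window, mirroring how speech_window
# (length 2 * source_cache_len) is applied to the streaming HiFT output below.
def _fade_in_out_example() -> torch.Tensor:
    window = torch.from_numpy(np.hamming(16))        # 8-sample overlap on each side
    prev_chunk = torch.zeros(1, 24)                  # tail of the previously emitted audio
    new_chunk = torch.ones(1, 24)                    # head of the newly generated audio
    return fade_in_out(new_chunk, prev_chunk, window)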
||||
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
|
||||
import tensorrt as trt
|
||||
logging.info("Converting onnx to trt...")
|
||||
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||
logger = trt.Logger(trt.Logger.INFO)
|
||||
builder = trt.Builder(logger)
|
||||
network = builder.create_network(network_flags)
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
config = builder.create_builder_config()
|
||||
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
|
||||
if dtype == torch.float16:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
elif dtype == torch.bfloat16:
|
||||
config.set_flag(trt.BuilderFlag.BF16)
|
||||
elif dtype == torch.float32:
|
||||
config.set_flag(trt.BuilderFlag.FP32)
|
||||
profile = builder.create_optimization_profile()
|
||||
# load onnx model
|
||||
with open(onnx_model, "rb") as f:
|
||||
if not parser.parse(f.read()):
|
||||
for error in range(parser.num_errors):
|
||||
print(parser.get_error(error))
|
||||
raise ValueError('failed to parse {}'.format(onnx_model))
|
||||
# set input shapes
|
||||
for i in range(len(trt_kwargs['input_names'])):
|
||||
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
|
||||
if dtype == torch.float16:
|
||||
tensor_dtype = trt.DataType.HALF
|
||||
elif dtype == torch.bfloat16:
|
||||
tensor_dtype = trt.DataType.BF16
|
||||
elif dtype == torch.float32:
|
||||
tensor_dtype = trt.DataType.FLOAT
|
||||
else:
|
||||
raise ValueError('invalid dtype {}'.format(dtype))
|
||||
# set input and output data type
|
||||
for i in range(network.num_inputs):
|
||||
input_tensor = network.get_input(i)
|
||||
input_tensor.dtype = tensor_dtype
|
||||
for i in range(network.num_outputs):
|
||||
output_tensor = network.get_output(i)
|
||||
output_tensor.dtype = tensor_dtype
|
||||
config.add_optimization_profile(profile)
|
||||
engine_bytes = builder.build_serialized_network(network, config)
|
||||
# save trt engine
|
||||
with open(trt_model, "wb") as f:
|
||||
f.write(engine_bytes)
|
||||
logging.info("Succesfully convert onnx to trt...")
|
||||
|
||||
class TrtContextWrapper:
|
||||
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
|
||||
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
|
||||
self.trt_engine = trt_engine
|
||||
self.device = device
|
||||
for _ in range(trt_concurrent):
|
||||
trt_context = trt_engine.create_execution_context()
|
||||
trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
|
||||
assert trt_context is not None, 'failed to create trt context, possibly not enough CUDA memory; try reducing the current trt_concurrent ({})'.format(trt_concurrent)
self.trt_context_pool.put([trt_context, trt_stream])
assert self.trt_context_pool.empty() is False, 'no available estimator context'
|
||||
|
||||
def acquire_estimator(self):
|
||||
return self.trt_context_pool.get(), self.trt_engine
|
||||
|
||||
def release_estimator(self, context, stream):
|
||||
self.trt_context_pool.put([context, stream])
|
||||
|
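# Usage sketch (assumes `wrapper` wraps an already-deserialized TensorRT engine):
# execution contexts are pooled so that trt_concurrent callers can run inference
# concurrently; a context must always be released back to the pool when done.
def _run_with_pooled_context(wrapper: "TrtContextWrapper"):
    [context, stream], engine = wrapper.acquire_estimator()
    try:
        with stream:
            pass  # set input shapes / tensor addresses here, then context.execute_async_v3(...)
    finally:
        wrapper.release_estimator(context, stream)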
||||
class CosyVoice2_Token2Wav(torch.nn.Module):
|
||||
def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
|
||||
super().__init__()
|
||||
self.device_id = device_id
|
||||
self.device = f"cuda:{device_id}"
|
||||
with open(f"{model_dir}/flow.yaml", "r") as f:
|
||||
configs = load_hyperpyyaml(f)
|
||||
self.flow = configs['flow']
|
||||
|
||||
self.dtype = dtype
|
||||
self.flow.to(self.dtype)
|
||||
|
||||
self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
|
||||
self.flow.to(self.device).eval()
|
||||
|
||||
self.hift = HiFTGenerator()
|
||||
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
|
||||
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||
self.hift.to(self.device).eval()
|
||||
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
|
||||
providers=["CPUExecutionProvider"])
|
||||
|
||||
self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()
|
||||
|
||||
gpu="l20"
|
||||
if enable_trt:
|
||||
if streaming:
|
||||
self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
|
||||
f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
|
||||
1,
|
||||
self.dtype, streaming)
|
||||
else:
|
||||
self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
|
||||
f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
|
||||
1,
|
||||
self.dtype)
|
||||
self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
|
||||
f'{model_dir}/campplus.onnx',
|
||||
1,
|
||||
False)
|
||||
|
||||
|
||||
self.streaming_flow_cache = {}
|
||||
self.speaker_cache = {}
|
||||
|
||||
self.mel_cache_len = 8 # hard-coded, 160ms
|
||||
self.source_cache_len = int(self.mel_cache_len * 480) # 50hz mel -> 24kHz wave
|
||||
self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda()
|
||||
|
||||
# hifigan cache for streaming tts
|
||||
self.hift_cache_dict = {}
|
||||
|
||||
def forward_spk_embedding(self, spk_feat):
|
||||
if isinstance(self.spk_model, onnxruntime.InferenceSession):
|
||||
return self.spk_model.run(
|
||||
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
|
||||
)[0].flatten().tolist()
|
||||
else:
|
||||
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
|
||||
# NOTE need to synchronize when switching stream
|
||||
with torch.cuda.device(self.device_id):
|
||||
torch.cuda.current_stream().synchronize()
|
||||
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
|
||||
batch_size = spk_feat.size(0)
|
||||
|
||||
with stream:
|
||||
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
|
||||
output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
|
||||
|
||||
data_ptrs = [spk_feat.contiguous().data_ptr(),
|
||||
output_tensor.contiguous().data_ptr()]
|
||||
for i, j in enumerate(data_ptrs):
|
||||
|
||||
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
||||
# run trt engine
|
||||
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.spk_model.release_estimator(spk_model, stream)
|
||||
|
||||
return output_tensor.cpu().numpy().flatten().tolist()
|
||||
|
||||
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
|
||||
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
|
||||
trt_kwargs = self.get_spk_trt_kwargs()
|
||||
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
|
||||
import tensorrt as trt
|
||||
with open(spk_model, 'rb') as f:
|
||||
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
|
||||
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||
|
||||
def get_spk_trt_kwargs(self):
|
||||
min_shape = [(1, 4, 80)]
|
||||
opt_shape = [(1, 500, 80)]
|
||||
max_shape = [(1, 3000, 80)]
|
||||
input_names = ["input"]
|
||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||
|
||||
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False):
|
||||
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
||||
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
|
||||
opt_batch_size = 2
|
||||
max_batch_size = 16
|
||||
if streaming:
|
||||
opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts
|
||||
trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
|
||||
convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
|
||||
del self.flow.decoder.estimator
|
||||
import tensorrt as trt
|
||||
with open(flow_decoder_estimator_model, 'rb') as f:
|
||||
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
|
||||
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||
|
||||
def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
|
||||
if streaming:
|
||||
min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
|
||||
opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80), (16, opt_batch_size*2, 1024, 2), (16, opt_batch_size*2, 8, 100, 128)]
|
||||
max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80), (16, max_batch_size*2, 1024, 2), (16, max_batch_size*2, 8, 1000, 128)]
|
||||
input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
|
||||
else:
|
||||
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
|
||||
opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)]
|
||||
max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)]
|
||||
input_names = ["x", "mask", "mu", "cond", "t", "spks"]
|
||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||
|
||||
def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
|
||||
prompt_speech_tokens_list, prompt_speech_mels_list = [], []
|
||||
for audio in prompt_audios_list:
|
||||
assert len(audio.shape) == 1
|
||||
log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T]
|
||||
prompt_speech_mels_list.append(log_mel)
|
||||
prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
|
||||
prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
|
||||
prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
|
||||
)
|
||||
for i in range(len(prompt_speech_tokens)):
|
||||
speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
|
||||
prompt_speech_tokens_list.append(speech_tokens_i)
|
||||
return prompt_speech_tokens_list
|
||||
|
||||
def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
|
||||
spk_emb_for_flow = []
|
||||
for audio in prompt_audios_list:
|
||||
assert len(audio.shape) == 1
|
||||
spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
|
||||
spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
|
||||
spk_emb = self.forward_spk_embedding(spk_feat)
|
||||
|
||||
spk_emb_for_flow.append(spk_emb)
|
||||
spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
|
||||
if self.dtype != torch.float32:
|
||||
spk_emb_for_flow = spk_emb_for_flow.to(self.dtype)
|
||||
return spk_emb_for_flow
|
||||
|
||||
def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
|
||||
prompt_mels_for_flow = []
|
||||
prompt_mels_lens_for_flow = []
|
||||
for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
|
||||
assert len(audio.shape) == 1
|
||||
audio = audio.unsqueeze(0)
|
||||
if sample_rate != 24000:
|
||||
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
|
||||
mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels]
|
||||
mel_len = mel.shape[0]
|
||||
prompt_mels_for_flow.append(mel)
|
||||
prompt_mels_lens_for_flow.append(mel_len)
|
||||
prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80]
|
||||
prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
|
||||
return prompt_mels_for_flow, prompt_mels_lens_for_flow
|
||||
|
||||
def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
|
||||
batch_size = prompt_mels_for_flow.shape[0]
|
||||
flow_inputs = []
|
||||
flow_inputs_lens = []
|
||||
for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
|
||||
flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
|
||||
flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
|
||||
|
||||
flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
|
||||
flow_inputs_lens = torch.tensor(flow_inputs_lens)
|
||||
|
||||
with torch.amp.autocast(self.device, dtype=torch.float16):
|
||||
generated_mels, generated_mels_lens = self.flow.inference(
|
||||
flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
|
||||
prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10
|
||||
)
|
||||
|
||||
return generated_mels, generated_mels_lens
|
||||
|
||||
def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
|
||||
batch_size = generated_mels.shape[0]
|
||||
generated_wavs = []
|
||||
for i in range(batch_size):
|
||||
mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
|
||||
wav, _ = self.hift(speech_feat=mel)
|
||||
generated_wavs.append(wav)
|
||||
return generated_wavs
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(
|
||||
self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
|
||||
):
|
||||
# assert all item in prompt_audios_sample_rate is 16000
|
||||
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
|
||||
|
||||
|
||||
prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
|
||||
|
||||
generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
|
||||
|
||||
generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
|
||||
|
||||
return generated_wavs
|
||||
|
||||
def prepare_prompt_audio(
|
||||
self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
|
||||
):
|
||||
# assert all item in prompt_audios_sample_rate is 16000
|
||||
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
|
||||
|
||||
|
||||
prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
|
||||
|
||||
prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
|
||||
|
||||
spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
|
||||
|
||||
return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
|
||||
|
||||
|
||||
def get_prompt_audio_cache_for_streaming_tts(
|
||||
self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
|
||||
):
|
||||
assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts"
|
||||
for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list):
|
||||
prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3])
|
||||
prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0)
|
||||
|
||||
cache = self.flow.setup_cache(
|
||||
prompt_speech_tokens_tensor.to(self.device),
|
||||
prompt_mels_for_flow.to(self.device),
|
||||
spk_emb_for_flow.to(self.device),
|
||||
n_timesteps=10
|
||||
)
|
||||
new_cache = {k: v.clone() for k, v in cache.items()}
|
||||
# Hack: clone to avoid in-place changes to cache['estimator_att_cache'] and cache['estimator_cnn_cache']
|
||||
return new_cache
|
||||
|
||||
|
||||
    @torch.inference_mode()
    def forward_streaming(
        self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
    ):
        if speaker_id not in self.speaker_cache:
            assert prompt_audio is not None, "prompt_audio is required for new speaker"
            assert prompt_audio_sample_rate == 16000

            prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate])

            # Align the prompt mel (2 mel frames per speech token) with the tokenized prompt length.
            token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0]))
            prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous()
            prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len]

            prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow}

            cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
            self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
            print(f"speaker_id {speaker_id} added to cache")

        # Clone the speaker-level caches so we can verify later that inference did not modify them in place.
        att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache'].clone()
        cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache'].clone()
        conformer_cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache'].clone()
        conformer_att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache'].clone()

        if request_id not in self.streaming_flow_cache:
            # First chunk of this request: give it an independent copy of the speaker cache
            # and an empty vocoder (HiFT) cache.
            self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
            self.hift_cache_dict[request_id] = dict(
                mel=torch.zeros(1, 80, 0, device='cuda'),
                source=torch.zeros(1, 1, 0, device='cuda'),
                speech=torch.zeros(1, 0, device='cuda'),
            )

        current_request_cache = self.streaming_flow_cache[request_id]
        current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
        generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')

        chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
            token=generated_speech_tokens,
            spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
            cache=current_request_cache,
            last_chunk=last_chunk,
            n_timesteps=10,
        )

        # Sanity check: the speaker-level caches must not have been modified in place.
        original_att_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache']
        original_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache']
        original_conformer_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache']
        original_conformer_att_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache']
        if not torch.allclose(original_att_cache, att_cache_clone):
            print("att_cache changed")
            # Print the last 10 elements of each tensor for inspection.
            print(original_att_cache[:, :, :, -10:])
            print(att_cache_clone[:, :, :, -10:])
            breakpoint()
        if not torch.allclose(original_cnn_cache, cnn_cache_clone):
            print("cnn_cache changed")
            print(original_cnn_cache[..., -10:])
            print(cnn_cache_clone[..., -10:])
            breakpoint()
        if not torch.allclose(original_conformer_cnn_cache, conformer_cnn_cache_clone):
            print("conformer_cnn_cache changed")
            print(original_conformer_cnn_cache[..., -10:])
            print(conformer_cnn_cache_clone[..., -10:])
            breakpoint()
        if not torch.allclose(original_conformer_att_cache, conformer_att_cache_clone):
            print("conformer_att_cache changed")
            print(original_conformer_att_cache[..., -10:])
            print(conformer_att_cache_clone[..., -10:])
            breakpoint()

        self.streaming_flow_cache[request_id] = new_streaming_flow_cache

        # Keep the attention cache bounded: retain the prompt frames plus the most recent 100 frames.
        if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
            self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
            ], dim=4)

        hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
        hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
        hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()
        mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone()

        speech, source = self.hift(mel, hift_cache_source)

        # Cross-fade with the previous chunk to smooth the overlap.
        if hift_cache_speech.shape[-1] > 0:
            speech = fade_in_out(speech, hift_cache_speech, self.speech_window)

        # Update the vocoder cache with the tail of this chunk.
        self.hift_cache_dict[request_id] = dict(
            mel=mel[..., -self.mel_cache_len:].clone().detach(),
            source=source[:, :, -self.source_cache_len:].clone().detach(),
            speech=speech[:, -self.source_cache_len:].clone().detach(),
        )
        if not last_chunk:
            # Hold back the cached tail; it will be cross-faded into the next chunk.
            speech = speech[:, :-self.source_cache_len]

        if last_chunk:
            assert request_id in self.streaming_flow_cache
            self.streaming_flow_cache.pop(request_id)
            self.hift_cache_dict.pop(request_id)
        return speech

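A hypothetical driver for the streaming path. The class name and constructor mirror the offline __main__ block below; the chunk size, file names, and token ids are assumptions for illustration only:

import torch
import torchaudio

model = CosyVoice2_Token2Wav(model_dir="./Step-Audio-2-mini/token2wav", enable_trt=False)
prompt, sr = torchaudio.load("prompt_16k.wav")   # must already be 16 kHz mono
assert sr == 16000

tokens = list(range(200))                        # dummy semantic token ids
chunk_size = 25                                  # assumed tokens per chunk

pieces = []
for start in range(0, len(tokens), chunk_size):
    chunk = tokens[start:start + chunk_size]
    last = start + chunk_size >= len(tokens)
    wav_chunk = model.forward_streaming(
        chunk, last_chunk=last, request_id="req-0", speaker_id="spk-0",
        prompt_audio=prompt.squeeze(0), prompt_audio_sample_rate=sr,
    )
    pieces.append(wav_chunk.cpu())

torchaudio.save("streamed.wav", torch.cat(pieces, dim=1), 24000)
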
def collate_fn(batch):
    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
    for item in batch:
        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
        audio = torch.from_numpy(item['prompt_audio']['array']).float()
        prompt_audios_list.append(audio)
        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
        ids.append(item['id'])
    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate

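For reference, collate_fn assumes each dataset item carries the fields accessed above. A minimal synthetic item (the array length and ids are made up) would look like this, relying on the collate_fn just defined:

import numpy as np

item = {
    "id": "utt_0001",
    "target_audio_cosy2_tokens": [101, 102, 103],     # generated semantic tokens (dummy ids)
    "prompt_audio": {
        "array": np.zeros(16000, dtype=np.float32),    # 1 s of 16 kHz prompt audio
        "sampling_rate": 16000,
    },
}
ids, tokens, audios, rates = collate_fn([item])
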
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--enable-trt", action="store_true")
    parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--output-dir", type=str, default="generated_wavs")
    parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs; performance statistics should be taken from the last epoch only")
    return parser.parse_args()

if __name__ == "__main__":
    args = get_args()
    model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
    os.makedirs(args.output_dir, exist_ok=True)
    dataset_name = "yuekai/seed_tts_cosy2"

    dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)

    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)

    for epoch in range(args.warmup):
        start_time = time.time()
        for batch in data_loader:
            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
            generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
            for utt_id, wav in zip(ids, generated_wavs):
                torchaudio.save(f"{args.output_dir}/{utt_id}.wav", wav.cpu(), 24000)
        end_time = time.time()
        epoch_time = end_time - start_time
        print(f"Epoch {epoch} time taken: {epoch_time:.4f} seconds")
@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

name: "token2wav"
name: "token2wav_dit"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
  max_queue_delay_microseconds: ${max_queue_delay_microseconds}
  priority_levels: 10
  default_priority_level: 10
}
parameters [
  {
@@ -32,29 +34,14 @@ input [
    dims: [-1]
  },
  {
    name: "prompt_speech_tokens"
    data_type: TYPE_INT32
    name: "reference_wav"
    data_type: TYPE_FP32
    dims: [-1]
    optional: true
  },
  {
    name: "prompt_speech_feat"
    data_type: TYPE_FP16
    dims: [-1, 80]
    optional: true
  },
  {
    name: "prompt_spk_embedding"
    data_type: TYPE_FP16
    dims: [-1]
    optional: true
  },
  {
    name: "token_offset"
    name: "reference_wav_len"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
    dims: [1]
  },
  {
    name: "finalize"