mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
clean code
This commit is contained in:
@@ -28,9 +28,10 @@ import json
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Tuple, Optional, Union
|
from typing import Dict, List, Tuple, Optional, Union
|
||||||
|
import asyncio
|
||||||
|
import httpx
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -42,11 +43,30 @@ import torchaudio
|
|||||||
|
|
||||||
|
|
||||||
from matcha.utils.audio import mel_spectrogram
|
from matcha.utils.audio import mel_spectrogram
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
ORIGINAL_VOCAB_SIZE = 151663
|
ORIGINAL_VOCAB_SIZE = 151663
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_speech_token_string(response_text: str) -> List[int]:
|
||||||
|
"""
|
||||||
|
Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
|
||||||
|
"""
|
||||||
|
speech_tokens = response_text.strip().split('><')
|
||||||
|
if len(speech_tokens) > 1:
|
||||||
|
# Add back the missing '<' and '>' for proper parsing
|
||||||
|
speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
|
||||||
|
speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
|
||||||
|
|
||||||
|
speech_ids = []
|
||||||
|
for token_str in speech_tokens:
|
||||||
|
match = re.match(r'<\|s_(\d+)\|>', token_str)
|
||||||
|
if match:
|
||||||
|
speech_ids.append(int(match.group(1)))
|
||||||
|
return speech_ids
|
||||||
|
|
||||||
|
|
||||||
class TritonPythonModel:
|
class TritonPythonModel:
|
||||||
"""Triton Python model for Spark TTS.
|
"""Triton Python model for Spark TTS.
|
||||||
|
|
||||||
@@ -67,6 +87,7 @@ class TritonPythonModel:
|
|||||||
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
||||||
self.logger.log_info(f"model_params:{model_params}")
|
self.logger.log_info(f"model_params:{model_params}")
|
||||||
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
|
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
|
||||||
|
# self.dynamic_chunk_strategy = "equal"
|
||||||
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
|
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
|
||||||
|
|
||||||
# Initialize tokenizer
|
# Initialize tokenizer
|
||||||
@@ -87,92 +108,86 @@ class TritonPythonModel:
|
|||||||
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
|
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
|
||||||
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
|
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
|
||||||
self.default_spk_info = spk_info["001"]
|
self.default_spk_info = spk_info["001"]
|
||||||
|
self.http_client = httpx.AsyncClient()
|
||||||
|
|
||||||
def forward_llm(self, input_ids):
|
def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
|
||||||
|
"""Converts a tensor or list of speech token IDs to a string representation."""
|
||||||
|
if isinstance(speech_tokens, torch.Tensor):
|
||||||
|
# Ensure tensor is on CPU and flattened
|
||||||
|
speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
|
||||||
|
|
||||||
|
speech_id_str = ""
|
||||||
|
for token_id in speech_tokens:
|
||||||
|
# Convert token ID back to the speech number N
|
||||||
|
token_num = token_id - ORIGINAL_VOCAB_SIZE
|
||||||
|
speech_id_str += f"<|s_{token_num}|>"
|
||||||
|
return speech_id_str
|
||||||
|
|
||||||
|
async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
|
||||||
"""
|
"""
|
||||||
Prepares the response from the language model based on the provided
|
Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
|
||||||
inputs. Creates a `pb_utils.InferenceRequest` object with passed
|
|
||||||
`llm_request_inputs` to send to a decoupled TensorRTLLM model.
|
|
||||||
For each response from the language model:
|
|
||||||
- Checks for errors and raise an exception if any are found.
|
|
||||||
- Extracts the "output_ids" tensor from the response.
|
|
||||||
- Determines the finish reason based on the presence of the
|
|
||||||
end-of-sequence token or reaching the maximum length.
|
|
||||||
- Appends the generated token IDs to `output_ids`.
|
|
||||||
- If the finish reason is determined, decodes the output IDs to text
|
|
||||||
and prepares the final response.
|
|
||||||
|
|
||||||
The final response includes the generated text, finish reason,
|
|
||||||
completion tokens, prompt tokens, and total tokens.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
- llm_request_inputs (dict): A dictionary containing the inputs for the language model.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
- pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
|
|
||||||
"""
|
"""
|
||||||
# convert input_ids to numpy, with shape [1, sequence_length]
|
full_text = f"{reference_text}{target_text}"
|
||||||
input_ids = input_ids.cpu().numpy()
|
prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
|
||||||
max_tokens = 750
|
|
||||||
input_dict = {
|
|
||||||
"request_output_len": np.array([[max_tokens]], dtype=np.int32),
|
|
||||||
"end_id": np.array([[self.eos_token_id]], dtype=np.int32),
|
|
||||||
"pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
|
|
||||||
"streaming": np.array([[self.decoupled]], dtype=np.bool_),
|
|
||||||
"runtime_top_p": np.array([[0.95]], dtype=np.float32),
|
|
||||||
"runtime_top_k": np.array([[50]], dtype=np.int32),
|
|
||||||
"temperature": np.array([[0.8]], dtype=np.float32),
|
|
||||||
"repetition_penalty": np.array([[1.1]], dtype=np.float32),
|
|
||||||
"random_seed": np.array([[42]], dtype=np.uint64),
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Convert inputs to Triton tensors
|
chat = [
|
||||||
input_tensor_list = [
|
{"role": "user", "content": full_text},
|
||||||
pb_utils.Tensor(k, v) for k, v in input_dict.items()
|
{"role": "assistant", "content": prompt_speech_tokens_str}
|
||||||
]
|
]
|
||||||
|
|
||||||
# Create and execute inference request
|
payload = {
|
||||||
llm_request = pb_utils.InferenceRequest(
|
"model": "trt_engines_bfloat16",
|
||||||
model_name="tensorrt_llm",
|
"messages": chat,
|
||||||
requested_output_names=["output_ids", "sequence_length"],
|
"max_tokens": 750,
|
||||||
inputs=input_tensor_list,
|
"temperature": 0.8,
|
||||||
)
|
"top_p": 0.95,
|
||||||
|
"top_k": 50,
|
||||||
|
"repetition_penalty": 1.1,
|
||||||
|
"stop": ["<|eos1|>", "<|eos|>"],
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
|
||||||
llm_responses = llm_request.exec(decoupled=self.decoupled)
|
api_base = "http://localhost:8000/v1/chat/completions"
|
||||||
if self.decoupled:
|
|
||||||
for llm_response in llm_responses:
|
|
||||||
if llm_response.has_error():
|
|
||||||
raise pb_utils.TritonModelException(llm_response.error().message())
|
|
||||||
|
|
||||||
# Extract and process output
|
buffer = ""
|
||||||
output_ids = pb_utils.get_output_tensor_by_name(
|
async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response:
|
||||||
llm_response, "output_ids").as_numpy()
|
print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}")
|
||||||
seq_lens = pb_utils.get_output_tensor_by_name(
|
print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}")
|
||||||
llm_response, "sequence_length").as_numpy()
|
response.raise_for_status()
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
line_data = line[len("data: "):].strip()
|
||||||
|
if line_data == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
json_data = json.loads(line_data)
|
||||||
|
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
|
||||||
|
if content:
|
||||||
|
buffer += content
|
||||||
|
print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}")
|
||||||
|
while True:
|
||||||
|
match = re.search(r"<\|s_(\d+)\|>", buffer)
|
||||||
|
if not match:
|
||||||
|
break
|
||||||
|
|
||||||
# Get actual output IDs up to the sequence length
|
token_num = int(match.group(1))
|
||||||
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
|
final_id = token_num + ORIGINAL_VOCAB_SIZE
|
||||||
|
yield final_id
|
||||||
|
buffer = buffer[match.end():]
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
self.logger.log_info(f"Skipping non-JSON line: {line_data}")
|
||||||
|
continue
|
||||||
|
|
||||||
yield actual_output_ids
|
# Process any remaining complete tokens in the buffer after the stream ends
|
||||||
else:
|
while True:
|
||||||
llm_response = llm_responses
|
match = re.search(r"<\|s_(\d+)\|>", buffer)
|
||||||
if llm_response.has_error():
|
if not match:
|
||||||
raise pb_utils.TritonModelException(llm_response.error().message())
|
break
|
||||||
|
token_num = int(match.group(1))
|
||||||
|
final_id = token_num + ORIGINAL_VOCAB_SIZE
|
||||||
|
yield final_id
|
||||||
|
buffer = buffer[match.end():]
|
||||||
|
|
||||||
# Extract and process output
|
|
||||||
output_ids = pb_utils.get_output_tensor_by_name(
|
|
||||||
llm_response, "output_ids").as_numpy()
|
|
||||||
seq_lens = pb_utils.get_output_tensor_by_name(
|
|
||||||
llm_response, "sequence_length").as_numpy()
|
|
||||||
|
|
||||||
# Get actual output IDs up to the sequence length
|
|
||||||
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
|
|
||||||
|
|
||||||
yield actual_output_ids
|
|
||||||
|
|
||||||
def forward_audio_tokenizer(self, wav, wav_len):
|
def forward_audio_tokenizer(self, wav, wav_len):
|
||||||
"""Forward pass through the audio tokenizer component.
|
"""Forward pass through the audio tokenizer component.
|
||||||
@@ -225,7 +240,7 @@ class TritonPythonModel:
|
|||||||
|
|
||||||
return prompt_spk_embedding
|
return prompt_spk_embedding
|
||||||
|
|
||||||
def forward_token2wav(
|
async def forward_token2wav(
|
||||||
self,
|
self,
|
||||||
index: int,
|
index: int,
|
||||||
target_speech_tokens: torch.Tensor,
|
target_speech_tokens: torch.Tensor,
|
||||||
@@ -247,17 +262,19 @@ class TritonPythonModel:
|
|||||||
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
|
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
|
||||||
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
|
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
|
||||||
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
|
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
|
||||||
|
|
||||||
# Create and execute inference request
|
# Create and execute inference request
|
||||||
inference_request = pb_utils.InferenceRequest(
|
inference_request = pb_utils.InferenceRequest(
|
||||||
model_name='token2wav_dit',
|
model_name='token2wav_dit',
|
||||||
requested_output_names=['waveform'],
|
requested_output_names=[
|
||||||
|
"waveform",
|
||||||
|
],
|
||||||
inputs=inputs_tensor,
|
inputs=inputs_tensor,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
parameters={"priority": index+1},
|
parameters={"priority": index+1},
|
||||||
)
|
)
|
||||||
|
|
||||||
inference_response = inference_request.exec()
|
inference_response = await inference_request.async_exec()
|
||||||
if inference_response.has_error():
|
if inference_response.has_error():
|
||||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||||
|
|
||||||
@@ -267,14 +284,6 @@ class TritonPythonModel:
|
|||||||
|
|
||||||
return waveform
|
return waveform
|
||||||
|
|
||||||
def parse_input(self, text, prompt_text, prompt_speech_tokens):
|
|
||||||
total_text = f"{prompt_text}{text}"
|
|
||||||
prompt = self.prompt_template.format(input_text=total_text)
|
|
||||||
input_ids = self.tokenizer.encode(prompt)
|
|
||||||
input_ids = torch.tensor([input_ids], dtype=torch.int32)
|
|
||||||
input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
|
|
||||||
return input_ids
|
|
||||||
|
|
||||||
def _extract_speech_feat(self, speech):
|
def _extract_speech_feat(self, speech):
|
||||||
speech_feat = mel_spectrogram(
|
speech_feat = mel_spectrogram(
|
||||||
speech,
|
speech,
|
||||||
@@ -292,106 +301,75 @@ class TritonPythonModel:
|
|||||||
speech_feat = speech_feat.unsqueeze(dim=0)
|
speech_feat = speech_feat.unsqueeze(dim=0)
|
||||||
return speech_feat
|
return speech_feat
|
||||||
|
|
||||||
def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
|
async def _process_request(self, request):
|
||||||
for generated_ids in generated_ids_iter:
|
request_id = request.request_id()
|
||||||
generated_ids = generated_ids.tolist()
|
# Extract input tensors
|
||||||
if len(generated_ids) == 0:
|
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
||||||
break
|
|
||||||
semantic_token_ids_arr.extend(generated_ids)
|
|
||||||
llm_is_done_flag[0] = True
|
|
||||||
|
|
||||||
def execute(self, requests):
|
# Process reference audio through audio tokenizer
|
||||||
"""Execute inference on the batched requests.
|
if wav is not None:
|
||||||
|
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
||||||
|
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
||||||
|
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
|
||||||
|
|
||||||
Args:
|
wav_tensor = wav.as_numpy()
|
||||||
requests: List of inference requests
|
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
|
||||||
|
print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}")
|
||||||
|
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
|
||||||
|
speech_feat = self._extract_speech_feat(prompt_speech_resample)
|
||||||
|
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
|
||||||
|
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
|
||||||
|
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
|
||||||
|
|
||||||
Returns:
|
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
||||||
List of inference responses containing generated audio
|
reference_text = reference_text[0][0].decode('utf-8')
|
||||||
"""
|
# prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
|
||||||
responses = []
|
|
||||||
|
|
||||||
for request in requests:
|
# reference_text = self.default_spk_info["prompt_text"]
|
||||||
request_id = request.request_id()
|
# prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
|
||||||
# Extract input tensors
|
# prompt_speech_feat = None
|
||||||
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
# prompt_spk_embedding = None
|
||||||
|
|
||||||
# Process reference audio through audio tokenizer
|
else:
|
||||||
if wav is not None:
|
# using pre-cached reference text
|
||||||
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
assert False, "using pre-cached reference text is not supported"
|
||||||
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
reference_text = self.default_spk_info["prompt_text"]
|
||||||
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
|
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
|
||||||
|
prompt_speech_feat = None
|
||||||
|
prompt_spk_embedding = None
|
||||||
|
|
||||||
wav_tensor = wav.as_numpy()
|
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
||||||
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
|
target_text = target_text[0][0].decode('utf-8')
|
||||||
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
|
print(f"target_text: {target_text}, time: {datetime.now()}")
|
||||||
speech_feat = self._extract_speech_feat(prompt_speech_resample)
|
|
||||||
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
|
|
||||||
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
|
|
||||||
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
|
|
||||||
|
|
||||||
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
if self.decoupled:
|
||||||
reference_text = reference_text[0][0].decode('utf-8')
|
response_sender = request.get_response_sender()
|
||||||
# prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
|
|
||||||
|
|
||||||
# reference_text = self.default_spk_info["prompt_text"]
|
semantic_token_ids_arr = []
|
||||||
# prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
|
token_offset, chunk_index = 0, 0
|
||||||
# prompt_speech_feat = None
|
start_time = time.time()
|
||||||
# prompt_spk_embedding = None
|
this_token_hop_len = self.token_hop_len
|
||||||
|
print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}")
|
||||||
else:
|
async for generated_ids in self.forward_llm_async(
|
||||||
assert False, "wav is None"
|
target_text=target_text,
|
||||||
# using pre-cached reference text
|
reference_text=reference_text,
|
||||||
reference_text = self.default_spk_info["prompt_text"]
|
|
||||||
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
|
|
||||||
prompt_speech_feat = None
|
|
||||||
prompt_spk_embedding = None
|
|
||||||
|
|
||||||
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
|
||||||
target_text = target_text[0][0].decode('utf-8')
|
|
||||||
|
|
||||||
# Prepare prompt for LLM
|
|
||||||
input_ids = self.parse_input(
|
|
||||||
text=target_text,
|
|
||||||
prompt_text=reference_text,
|
|
||||||
prompt_speech_tokens=prompt_speech_tokens,
|
prompt_speech_tokens=prompt_speech_tokens,
|
||||||
)
|
):
|
||||||
|
if not generated_ids:
|
||||||
# Generate semantic tokens with LLM
|
break
|
||||||
generated_ids_iter = self.forward_llm(input_ids)
|
semantic_token_ids_arr.append(generated_ids)
|
||||||
|
print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}")
|
||||||
if self.decoupled:
|
|
||||||
response_sender = request.get_response_sender()
|
|
||||||
|
|
||||||
semantic_token_ids_arr = []
|
|
||||||
llm_is_done_flag = [False]
|
|
||||||
|
|
||||||
llm_thread = threading.Thread(
|
|
||||||
target=self._llm_gen_thread,
|
|
||||||
args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
|
|
||||||
)
|
|
||||||
|
|
||||||
llm_thread.start()
|
|
||||||
|
|
||||||
token_offset, chunk_index = 0, 0
|
|
||||||
start_time = time.time()
|
|
||||||
this_token_hop_len = self.token_hop_len
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
pending_num = len(semantic_token_ids_arr) - token_offset
|
pending_num = len(semantic_token_ids_arr) - token_offset
|
||||||
|
|
||||||
if llm_is_done_flag[0]:
|
|
||||||
break
|
|
||||||
|
|
||||||
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
|
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
|
||||||
this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
|
this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
|
||||||
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
||||||
|
print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}")
|
||||||
sub_tts_speech = self.forward_token2wav(
|
sub_tts_speech = await self.forward_token2wav(
|
||||||
chunk_index,
|
chunk_index,
|
||||||
this_tts_speech_token, request_id, wav, wav_len, False
|
this_tts_speech_token, request_id, wav, wav_len, False
|
||||||
)
|
)
|
||||||
|
print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}")
|
||||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
||||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||||
response_sender.send(inference_response)
|
response_sender.send(inference_response)
|
||||||
@@ -401,6 +379,8 @@ class TritonPythonModel:
|
|||||||
|
|
||||||
if self.dynamic_chunk_strategy == "exponential":
|
if self.dynamic_chunk_strategy == "exponential":
|
||||||
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
|
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
|
||||||
|
elif self.dynamic_chunk_strategy == "equal":
|
||||||
|
this_token_hop_len = self.token_hop_len
|
||||||
elif self.dynamic_chunk_strategy == "time_based":
|
elif self.dynamic_chunk_strategy == "time_based":
|
||||||
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
|
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
|
||||||
cost_time = time.time() - start_time
|
cost_time = time.time() - start_time
|
||||||
@@ -420,19 +400,36 @@ class TritonPythonModel:
|
|||||||
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
|
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
|
||||||
chunk_index += 1
|
chunk_index += 1
|
||||||
else:
|
else:
|
||||||
time.sleep(0.02)
|
break
|
||||||
|
|
||||||
|
this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
||||||
|
sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
|
||||||
|
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
||||||
|
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||||
|
response_sender.send(inference_response)
|
||||||
|
|
||||||
this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
|
||||||
sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
|
self.logger.log_info("send tritonserver_response_complete_final to end")
|
||||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
else:
|
||||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
raise NotImplementedError("Decoupled mode is not supported")
|
||||||
response_sender.send(inference_response)
|
|
||||||
|
|
||||||
llm_thread.join()
|
async def execute(self, requests):
|
||||||
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
|
"""Execute inference on the batched requests.
|
||||||
self.logger.log_info("send tritonserver_response_complete_final to end")
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Decoupled mode is not supported")
|
|
||||||
|
|
||||||
if not self.decoupled:
|
Args:
|
||||||
return responses
|
requests: List of inference requests
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of inference responses containing generated audio
|
||||||
|
"""
|
||||||
|
tasks = [
|
||||||
|
asyncio.create_task(self._process_request(request))
|
||||||
|
for request in requests
|
||||||
|
]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
self.logger.log_info("Finalizing CosyVoice DIT model")
|
||||||
|
if hasattr(self, "http_client"):
|
||||||
|
asyncio.run(self.http_client.aclose())
|
||||||
|
|||||||
@@ -1,435 +0,0 @@
|
|||||||
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
||||||
#
|
|
||||||
# Redistribution and use in source and binary forms, with or without
|
|
||||||
# modification, are permitted provided that the following conditions
|
|
||||||
# are met:
|
|
||||||
# * Redistributions of source code must retain the above copyright
|
|
||||||
# notice, this list of conditions and the following disclaimer.
|
|
||||||
# * Redistributions in binary form must reproduce the above copyright
|
|
||||||
# notice, this list of conditions and the following disclaimer in the
|
|
||||||
# documentation and/or other materials provided with the distribution.
|
|
||||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
|
||||||
# contributors may be used to endorse or promote products derived
|
|
||||||
# from this software without specific prior written permission.
|
|
||||||
#
|
|
||||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
|
||||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
||||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
|
||||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
|
||||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
|
||||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
|
||||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
|
||||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
|
||||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
||||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
||||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
||||||
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from typing import Dict, List, Tuple, Optional, Union
|
|
||||||
import asyncio
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
|
||||||
import triton_python_backend_utils as pb_utils
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
import torchaudio
|
|
||||||
|
|
||||||
|
|
||||||
from matcha.utils.audio import mel_spectrogram
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
ORIGINAL_VOCAB_SIZE = 151663
|
|
||||||
torch.set_num_threads(1)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_speech_token_string(response_text: str) -> List[int]:
|
|
||||||
"""
|
|
||||||
Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
|
|
||||||
"""
|
|
||||||
speech_tokens = response_text.strip().split('><')
|
|
||||||
if len(speech_tokens) > 1:
|
|
||||||
# Add back the missing '<' and '>' for proper parsing
|
|
||||||
speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
|
|
||||||
speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
|
|
||||||
|
|
||||||
speech_ids = []
|
|
||||||
for token_str in speech_tokens:
|
|
||||||
match = re.match(r'<\|s_(\d+)\|>', token_str)
|
|
||||||
if match:
|
|
||||||
speech_ids.append(int(match.group(1)))
|
|
||||||
return speech_ids
|
|
||||||
|
|
||||||
|
|
||||||
class TritonPythonModel:
|
|
||||||
"""Triton Python model for Spark TTS.
|
|
||||||
|
|
||||||
This model orchestrates the end-to-end TTS pipeline by coordinating
|
|
||||||
between audio tokenizer, LLM, and vocoder components.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def initialize(self, args):
|
|
||||||
"""Initialize the model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args: Dictionary containing model configuration
|
|
||||||
"""
|
|
||||||
self.logger = pb_utils.Logger
|
|
||||||
# Parse model parameters
|
|
||||||
self.model_config = json.loads(args['model_config'])
|
|
||||||
parameters = self.model_config['parameters']
|
|
||||||
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
|
||||||
self.logger.log_info(f"model_params:{model_params}")
|
|
||||||
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
|
|
||||||
# self.dynamic_chunk_strategy = "equal"
|
|
||||||
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
|
|
||||||
|
|
||||||
# Initialize tokenizer
|
|
||||||
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
|
|
||||||
self.prompt_template = "<|sos|>{input_text}<|task_id|>"
|
|
||||||
self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
|
|
||||||
|
|
||||||
self.device = torch.device("cuda")
|
|
||||||
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
|
|
||||||
|
|
||||||
self.token_frame_rate = 25
|
|
||||||
self.flow_pre_lookahead_len = 3
|
|
||||||
self.token_hop_len = 15
|
|
||||||
|
|
||||||
spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
|
|
||||||
if not os.path.exists(spk_info_path):
|
|
||||||
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
|
|
||||||
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
|
|
||||||
self.default_spk_info = spk_info["001"]
|
|
||||||
self.http_client = httpx.AsyncClient()
|
|
||||||
|
|
||||||
def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
|
|
||||||
"""Converts a tensor or list of speech token IDs to a string representation."""
|
|
||||||
if isinstance(speech_tokens, torch.Tensor):
|
|
||||||
# Ensure tensor is on CPU and flattened
|
|
||||||
speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
|
|
||||||
|
|
||||||
speech_id_str = ""
|
|
||||||
for token_id in speech_tokens:
|
|
||||||
# Convert token ID back to the speech number N
|
|
||||||
token_num = token_id - ORIGINAL_VOCAB_SIZE
|
|
||||||
speech_id_str += f"<|s_{token_num}|>"
|
|
||||||
return speech_id_str
|
|
||||||
|
|
||||||
async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
|
|
||||||
"""
|
|
||||||
Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
|
|
||||||
"""
|
|
||||||
full_text = f"{reference_text}{target_text}"
|
|
||||||
prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
|
|
||||||
|
|
||||||
chat = [
|
|
||||||
{"role": "user", "content": full_text},
|
|
||||||
{"role": "assistant", "content": prompt_speech_tokens_str}
|
|
||||||
]
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"model": "trt_engines_bfloat16",
|
|
||||||
"messages": chat,
|
|
||||||
"max_tokens": 750,
|
|
||||||
"temperature": 0.8,
|
|
||||||
"top_p": 0.95,
|
|
||||||
"top_k": 50,
|
|
||||||
"repetition_penalty": 1.1,
|
|
||||||
"stop": ["<|eos1|>", "<|eos|>"],
|
|
||||||
"stream": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
api_base = "http://localhost:8000/v1/chat/completions"
|
|
||||||
|
|
||||||
buffer = ""
|
|
||||||
async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response:
|
|
||||||
print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}")
|
|
||||||
print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}")
|
|
||||||
response.raise_for_status()
|
|
||||||
async for line in response.aiter_lines():
|
|
||||||
if line.startswith("data: "):
|
|
||||||
line_data = line[len("data: "):].strip()
|
|
||||||
if line_data == "[DONE]":
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
json_data = json.loads(line_data)
|
|
||||||
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
|
|
||||||
if content:
|
|
||||||
buffer += content
|
|
||||||
print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}")
|
|
||||||
while True:
|
|
||||||
match = re.search(r"<\|s_(\d+)\|>", buffer)
|
|
||||||
if not match:
|
|
||||||
break
|
|
||||||
|
|
||||||
token_num = int(match.group(1))
|
|
||||||
final_id = token_num + ORIGINAL_VOCAB_SIZE
|
|
||||||
yield final_id
|
|
||||||
buffer = buffer[match.end():]
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
self.logger.log_info(f"Skipping non-JSON line: {line_data}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Process any remaining complete tokens in the buffer after the stream ends
|
|
||||||
while True:
|
|
||||||
match = re.search(r"<\|s_(\d+)\|>", buffer)
|
|
||||||
if not match:
|
|
||||||
break
|
|
||||||
token_num = int(match.group(1))
|
|
||||||
final_id = token_num + ORIGINAL_VOCAB_SIZE
|
|
||||||
yield final_id
|
|
||||||
buffer = buffer[match.end():]
|
|
||||||
|
|
||||||
|
|
||||||
def forward_audio_tokenizer(self, wav, wav_len):
|
|
||||||
"""Forward pass through the audio tokenizer component.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
wav: Input waveform tensor
|
|
||||||
wav_len: Waveform length tensor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of global and semantic tokens
|
|
||||||
"""
|
|
||||||
inference_request = pb_utils.InferenceRequest(
|
|
||||||
model_name='audio_tokenizer',
|
|
||||||
requested_output_names=['prompt_speech_tokens'],
|
|
||||||
inputs=[wav, wav_len]
|
|
||||||
)
|
|
||||||
|
|
||||||
inference_response = inference_request.exec()
|
|
||||||
if inference_response.has_error():
|
|
||||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
|
||||||
|
|
||||||
# Extract and convert output tensors
|
|
||||||
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
|
|
||||||
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
|
|
||||||
|
|
||||||
return prompt_speech_tokens
|
|
||||||
|
|
||||||
def forward_speaker_embedding(self, wav):
|
|
||||||
"""Forward pass through the speaker embedding component.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
wav: Input waveform tensor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Prompt speaker embedding tensor
|
|
||||||
"""
|
|
||||||
inference_request = pb_utils.InferenceRequest(
|
|
||||||
model_name='speaker_embedding',
|
|
||||||
requested_output_names=['prompt_spk_embedding'],
|
|
||||||
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
|
|
||||||
)
|
|
||||||
|
|
||||||
inference_response = inference_request.exec()
|
|
||||||
if inference_response.has_error():
|
|
||||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
|
||||||
|
|
||||||
# Extract and convert output tensors
|
|
||||||
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
|
|
||||||
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
|
|
||||||
|
|
||||||
return prompt_spk_embedding
|
|
||||||
|
|
||||||
async def forward_token2wav(
|
|
||||||
self,
|
|
||||||
index: int,
|
|
||||||
target_speech_tokens: torch.Tensor,
|
|
||||||
request_id: str,
|
|
||||||
reference_wav: object,
|
|
||||||
reference_wav_len: object,
|
|
||||||
finalize: bool = None) -> torch.Tensor:
|
|
||||||
"""Forward pass through the vocoder component.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt_speech_tokens: Prompt speech tokens tensor
|
|
||||||
prompt_speech_feat: Prompt speech feat tensor
|
|
||||||
prompt_spk_embedding: Prompt spk embedding tensor
|
|
||||||
target_speech_tokens: Target speech tokens tensor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Generated waveform tensor
|
|
||||||
"""
|
|
||||||
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
|
|
||||||
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
|
|
||||||
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
|
|
||||||
|
|
||||||
# Create and execute inference request
|
|
||||||
inference_request = pb_utils.InferenceRequest(
|
|
||||||
model_name='token2wav_dit',
|
|
||||||
requested_output_names=[
|
|
||||||
"waveform",
|
|
||||||
],
|
|
||||||
inputs=inputs_tensor,
|
|
||||||
request_id=request_id,
|
|
||||||
parameters={"priority": index+1},
|
|
||||||
)
|
|
||||||
|
|
||||||
inference_response = await inference_request.async_exec()
|
|
||||||
if inference_response.has_error():
|
|
||||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
|
||||||
|
|
||||||
# Extract and convert output waveform
|
|
||||||
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
|
|
||||||
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
|
|
||||||
|
|
||||||
return waveform
|
|
||||||
|
|
||||||
def _extract_speech_feat(self, speech):
|
|
||||||
speech_feat = mel_spectrogram(
|
|
||||||
speech,
|
|
||||||
n_fft=1920,
|
|
||||||
num_mels=80,
|
|
||||||
sampling_rate=24000,
|
|
||||||
hop_size=480,
|
|
||||||
win_size=1920,
|
|
||||||
fmin=0,
|
|
||||||
fmax=8000).squeeze(
|
|
||||||
dim=0).transpose(
|
|
||||||
0,
|
|
||||||
1).to(
|
|
||||||
self.device)
|
|
||||||
speech_feat = speech_feat.unsqueeze(dim=0)
|
|
||||||
return speech_feat
|
|
||||||
|
|
||||||
async def _process_request(self, request):
|
|
||||||
request_id = request.request_id()
|
|
||||||
# Extract input tensors
|
|
||||||
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
|
||||||
|
|
||||||
# Process reference audio through audio tokenizer
|
|
||||||
if wav is not None:
|
|
||||||
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
|
||||||
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
|
||||||
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
|
|
||||||
|
|
||||||
wav_tensor = wav.as_numpy()
|
|
||||||
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
|
|
||||||
print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}")
|
|
||||||
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
|
|
||||||
speech_feat = self._extract_speech_feat(prompt_speech_resample)
|
|
||||||
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
|
|
||||||
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
|
|
||||||
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
|
|
||||||
|
|
||||||
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
|
||||||
reference_text = reference_text[0][0].decode('utf-8')
|
|
||||||
# prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
|
|
||||||
|
|
||||||
# reference_text = self.default_spk_info["prompt_text"]
|
|
||||||
# prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
|
|
||||||
# prompt_speech_feat = None
|
|
||||||
# prompt_spk_embedding = None
|
|
||||||
|
|
||||||
else:
|
|
||||||
# using pre-cached reference text
|
|
||||||
assert False, "using pre-cached reference text is not supported"
|
|
||||||
reference_text = self.default_spk_info["prompt_text"]
|
|
||||||
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
|
|
||||||
prompt_speech_feat = None
|
|
||||||
prompt_spk_embedding = None
|
|
||||||
|
|
||||||
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
|
||||||
target_text = target_text[0][0].decode('utf-8')
|
|
||||||
print(f"target_text: {target_text}, time: {datetime.now()}")
|
|
||||||
|
|
||||||
if self.decoupled:
|
|
||||||
response_sender = request.get_response_sender()
|
|
||||||
|
|
||||||
semantic_token_ids_arr = []
|
|
||||||
token_offset, chunk_index = 0, 0
|
|
||||||
start_time = time.time()
|
|
||||||
this_token_hop_len = self.token_hop_len
|
|
||||||
print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}")
|
|
||||||
async for generated_ids in self.forward_llm_async(
|
|
||||||
target_text=target_text,
|
|
||||||
reference_text=reference_text,
|
|
||||||
prompt_speech_tokens=prompt_speech_tokens,
|
|
||||||
):
|
|
||||||
if not generated_ids:
|
|
||||||
break
|
|
||||||
semantic_token_ids_arr.append(generated_ids)
|
|
||||||
print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}")
|
|
||||||
while True:
|
|
||||||
pending_num = len(semantic_token_ids_arr) - token_offset
|
|
||||||
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
|
|
||||||
this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
|
|
||||||
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
|
||||||
print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}")
|
|
||||||
sub_tts_speech = await self.forward_token2wav(
|
|
||||||
chunk_index,
|
|
||||||
this_tts_speech_token, request_id, wav, wav_len, False
|
|
||||||
)
|
|
||||||
print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}")
|
|
||||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
|
||||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
|
||||||
response_sender.send(inference_response)
|
|
||||||
|
|
||||||
token_offset += this_token_hop_len
|
|
||||||
self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
|
|
||||||
|
|
||||||
if self.dynamic_chunk_strategy == "exponential":
|
|
||||||
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
|
|
||||||
elif self.dynamic_chunk_strategy == "equal":
|
|
||||||
this_token_hop_len = self.token_hop_len
|
|
||||||
elif self.dynamic_chunk_strategy == "time_based":
|
|
||||||
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
|
|
||||||
cost_time = time.time() - start_time
|
|
||||||
duration = token_offset / self.token_frame_rate
|
|
||||||
if chunk_index > 0 and cost_time > 0:
|
|
||||||
avg_chunk_processing_time = cost_time / (chunk_index + 1)
|
|
||||||
if avg_chunk_processing_time > 0:
|
|
||||||
multiples = (duration - cost_time) / avg_chunk_processing_time
|
|
||||||
self.logger.log_info(f"multiples: {multiples}")
|
|
||||||
next_pending_num = len(semantic_token_ids_arr) - token_offset
|
|
||||||
if multiples > 4:
|
|
||||||
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
|
|
||||||
elif multiples > 2:
|
|
||||||
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
|
|
||||||
else:
|
|
||||||
this_token_hop_len = self.token_hop_len
|
|
||||||
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
|
|
||||||
chunk_index += 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
|
||||||
sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
|
|
||||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
|
||||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
|
||||||
response_sender.send(inference_response)
|
|
||||||
|
|
||||||
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
|
|
||||||
self.logger.log_info("send tritonserver_response_complete_final to end")
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Decoupled mode is not supported")
|
|
||||||
|
|
||||||
async def execute(self, requests):
|
|
||||||
"""Execute inference on the batched requests.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
requests: List of inference requests
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of inference responses containing generated audio
|
|
||||||
"""
|
|
||||||
tasks = [
|
|
||||||
asyncio.create_task(self._process_request(request))
|
|
||||||
for request in requests
|
|
||||||
]
|
|
||||||
await asyncio.gather(*tasks)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def finalize(self):
|
|
||||||
self.logger.log_info("Finalizing CosyVoice DIT model")
|
|
||||||
if hasattr(self, "http_client"):
|
|
||||||
asyncio.run(self.http_client.aclose())
|
|
||||||
@@ -47,8 +47,6 @@ import requests
|
|||||||
import asyncio
|
import asyncio
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from token2wav import CosyVoice2_Token2Wav
|
|
||||||
|
|
||||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||||
try:
|
try:
|
||||||
torch.multiprocessing.set_start_method("spawn")
|
torch.multiprocessing.set_start_method("spawn")
|
||||||
@@ -367,7 +365,12 @@ def main(args):
|
|||||||
runner = None
|
runner = None
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||||
|
|
||||||
|
if 'Step-Audio-2-mini' in args.token2wav_path:
|
||||||
|
from token2wav_dit import CosyVoice2_Token2Wav
|
||||||
|
else:
|
||||||
|
assert 'CosyVoice2-0.5B' in args.token2wav_path
|
||||||
|
from token2wav import CosyVoice2_Token2Wav
|
||||||
token2wav_model = CosyVoice2_Token2Wav(
|
token2wav_model = CosyVoice2_Token2Wav(
|
||||||
model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
|
model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
|
||||||
)
|
)
|
||||||
@@ -589,7 +592,6 @@ def main(args):
|
|||||||
t2w_prompt_audios_list,
|
t2w_prompt_audios_list,
|
||||||
t2w_prompt_audios_sample_rate,
|
t2w_prompt_audios_sample_rate,
|
||||||
)
|
)
|
||||||
torch.cuda.synchronize()
|
|
||||||
token2wav_end_time = time.time()
|
token2wav_end_time = time.time()
|
||||||
total_token2wav_time += (token2wav_end_time - token2wav_start_time)
|
total_token2wav_time += (token2wav_end_time - token2wav_start_time)
|
||||||
|
|
||||||
|
|||||||
@@ -1,28 +1,33 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
|
# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
|
||||||
export CUDA_VISIBLE_DEVICES=0
|
export CUDA_VISIBLE_DEVICES=0
|
||||||
cosyvoice_path=/workspace/CosyVoice
|
# cosyvoice_path=/workspace/CosyVoice
|
||||||
cosyvoice_path=/workspace_yuekai/tts/CosyVoice
|
cosyvoice_path=/workspace_yuekai/tts/CosyVoice
|
||||||
stepaudio2_path=/workspace_yuekai/tts/Step-Audio2
|
stepaudio2_path=/workspace_yuekai/tts/Step-Audio2
|
||||||
|
|
||||||
export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH
|
export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH
|
||||||
export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
|
export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
|
||||||
export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
|
export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
|
||||||
|
|
||||||
stage=$1
|
stage=$1
|
||||||
stop_stage=$2
|
stop_stage=$2
|
||||||
N_GPUS=2 # set the number of GPUs to use
|
|
||||||
|
|
||||||
|
|
||||||
huggingface_model_local_dir=./cosyvoice2_llm
|
huggingface_model_local_dir=./cosyvoice2_llm
|
||||||
model_scope_model_local_dir=./CosyVoice2-0.5B
|
model_scope_model_local_dir=./CosyVoice2-0.5B
|
||||||
|
step_audio_model_dir=./Step-Audio-2-mini
|
||||||
|
|
||||||
trt_dtype=bfloat16
|
trt_dtype=bfloat16
|
||||||
trt_weights_dir=./trt_weights_${trt_dtype}
|
trt_weights_dir=./trt_weights_${trt_dtype}
|
||||||
trt_engines_dir=./trt_engines_${trt_dtype}
|
trt_engines_dir=./trt_engines_${trt_dtype}
|
||||||
|
|
||||||
model_repo=./model_repo_cosyvoice2_dit
|
model_repo=./model_repo_cosyvoice2_dit
|
||||||
|
bls_instance_num=4
|
||||||
use_spk2info_cache=False
|
|
||||||
|
|
||||||
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||||
|
|
||||||
|
echo "Cloning Step-Audio2-mini"
|
||||||
|
git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt $stepaudio2_path
|
||||||
|
|
||||||
echo "Cloning CosyVoice"
|
echo "Cloning CosyVoice"
|
||||||
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
|
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
|
||||||
cd $cosyvoice_path
|
cd $cosyvoice_path
|
||||||
@@ -35,8 +40,13 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
|||||||
# see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
|
# see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
|
||||||
huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
|
huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
|
||||||
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
|
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
|
||||||
# download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings
|
|
||||||
wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt
|
echo "Step-Audio2-mini"
|
||||||
|
huggingface-cli download --local-dir $step_audio_model_dir stepfun-ai/Step-Audio-2-mini
|
||||||
|
cd $stepaudio2_path/token2wav
|
||||||
|
wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O flow.decoder.estimator.fp32.dynamic_batch.onnx
|
||||||
|
wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx -O flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx
|
||||||
|
cd -
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
@@ -60,40 +70,6 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
|||||||
--engine_dir=$trt_engines_dir || exit 1
|
--engine_dir=$trt_engines_dir || exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
# if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
|
||||||
# echo "Creating model repository"
|
|
||||||
# rm -rf $model_repo
|
|
||||||
# mkdir -p $model_repo
|
|
||||||
# cosyvoice2_dir="cosyvoice2_dit"
|
|
||||||
# token2wav_dir="token2wav_dit"
|
|
||||||
|
|
||||||
# cp -r ./model_repo/${cosyvoice2_dir} $model_repo
|
|
||||||
# cp -r ./model_repo/tensorrt_llm $model_repo
|
|
||||||
# cp -r ./model_repo/${token2wav_dir} $model_repo
|
|
||||||
# #if [ $use_spk2info_cache == "False" ]; then
|
|
||||||
# cp -r ./model_repo/audio_tokenizer $model_repo
|
|
||||||
# cp -r ./model_repo/speaker_embedding $model_repo
|
|
||||||
# #fi
|
|
||||||
|
|
||||||
# ENGINE_PATH=$trt_engines_dir
|
|
||||||
# MAX_QUEUE_DELAY_MICROSECONDS=0
|
|
||||||
# MODEL_DIR=$model_scope_model_local_dir
|
|
||||||
# LLM_TOKENIZER_DIR=$huggingface_model_local_dir
|
|
||||||
# BLS_INSTANCE_NUM=1
|
|
||||||
# TRITON_MAX_BATCH_SIZE=16
|
|
||||||
# DECOUPLED_MODE=True # True for streaming, False for offline
|
|
||||||
# STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav
|
|
||||||
|
|
||||||
# python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
|
||||||
# python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
|
||||||
# python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
|
|
||||||
# #if [ $use_spk2info_cache == "False" ]; then
|
|
||||||
# python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
|
||||||
# python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
|
||||||
# #fi
|
|
||||||
# fi
|
|
||||||
|
|
||||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||||
echo "Creating model repository async mode"
|
echo "Creating model repository async mode"
|
||||||
rm -rf $model_repo
|
rm -rf $model_repo
|
||||||
@@ -102,122 +78,75 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
|||||||
token2wav_dir="token2wav_dit"
|
token2wav_dir="token2wav_dit"
|
||||||
|
|
||||||
cp -r ./model_repo/${cosyvoice2_dir} $model_repo
|
cp -r ./model_repo/${cosyvoice2_dir} $model_repo
|
||||||
cp -r ./model_repo/tensorrt_llm $model_repo
|
|
||||||
cp -r ./model_repo/${token2wav_dir} $model_repo
|
cp -r ./model_repo/${token2wav_dir} $model_repo
|
||||||
#if [ $use_spk2info_cache == "False" ]; then
|
cp -r ./model_repo/audio_tokenizer $model_repo
|
||||||
cp -r ./model_repo/audio_tokenizer $model_repo
|
cp -r ./model_repo/speaker_embedding $model_repo
|
||||||
cp -r ./model_repo/speaker_embedding $model_repo
|
|
||||||
#fi
|
|
||||||
|
|
||||||
ENGINE_PATH=$trt_engines_dir
|
ENGINE_PATH=$trt_engines_dir
|
||||||
MAX_QUEUE_DELAY_MICROSECONDS=0
|
MAX_QUEUE_DELAY_MICROSECONDS=0
|
||||||
MODEL_DIR=$model_scope_model_local_dir
|
MODEL_DIR=$model_scope_model_local_dir
|
||||||
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
|
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
|
||||||
BLS_INSTANCE_NUM=4
|
BLS_INSTANCE_NUM=$bls_instance_num
|
||||||
TRITON_MAX_BATCH_SIZE=1
|
TRITON_MAX_BATCH_SIZE=1
|
||||||
DECOUPLED_MODE=True # True for streaming, False for offline
|
DECOUPLED_MODE=True
|
||||||
STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav
|
STEP_AUDIO_MODEL_DIR=$step_audio_model_dir/token2wav
|
||||||
|
|
||||||
python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||||
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||||
python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
|
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||||
#if [ $use_spk2info_cache == "False" ]; then
|
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||||
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
|
||||||
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
|
||||||
#fi
|
|
||||||
rm -rf $model_repo/tensorrt_llm
|
|
||||||
# mv $model_repo/cosyvoice2_dit/1 $model_repo/cosyvoice2_dit/4
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||||
echo "Starting Triton server on $N_GPUS GPUs"
|
echo "Starting Token2wav Triton server and Cosyvoice2 llm using trtllm-serve"
|
||||||
for i in $(seq 0 $(($N_GPUS - 1))); do
|
tritonserver --model-repository $model_repo --http-port 18000 &
|
||||||
echo "Starting server on GPU $i"
|
mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4 &
|
||||||
http_port=$((19000 + $i))
|
|
||||||
grpc_port=$((18000 + $i))
|
|
||||||
metrics_port=$((17000 + $i))
|
|
||||||
CUDA_VISIBLE_DEVICES=$i tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
|
|
||||||
done
|
|
||||||
|
|
||||||
echo "Servers are running in the background. Press Ctrl+C to stop them and the script."
|
|
||||||
wait
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
|
|
||||||
echo "Starting Triton server on $N_GPUS GPUs"
|
|
||||||
N_GPUS=1
|
|
||||||
for i in $(seq 0 $(($N_GPUS - 1))); do
|
|
||||||
echo "Starting server on GPU $i"
|
|
||||||
http_port=$((19000 + $i))
|
|
||||||
grpc_port=$((18000 + $i))
|
|
||||||
metrics_port=$((17000 + $i))
|
|
||||||
CUDA_VISIBLE_DEVICES=0 tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
|
|
||||||
done
|
|
||||||
|
|
||||||
echo "Servers are running in the background. Press Ctrl+C to stop them and the script."
|
|
||||||
wait
|
wait
|
||||||
|
# Test using curl
|
||||||
|
# curl http://localhost:8000/v1/chat/completions \
|
||||||
|
# -H "Content-Type: application/json" \
|
||||||
|
# -d '{
|
||||||
|
# "model": "trt_engines_bfloat16",
|
||||||
|
# "messages":[{"role": "user", "content": "Where is New York?"},
|
||||||
|
# {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}],
|
||||||
|
# "max_tokens": 512,
|
||||||
|
# "temperature": 0.8,
|
||||||
|
# "top_p": 0.95,
|
||||||
|
# "top_k": 50,
|
||||||
|
# "stop": ["<|eos1|>"],
|
||||||
|
# "repetition_penalty": 1.2,
|
||||||
|
# "stream": false
|
||||||
|
# }'
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
echo "Single request test http, only work for offline TTS mode"
|
echo "Running benchmark client"
|
||||||
python3 client_http.py \
|
num_task=4
|
||||||
--reference-audio ./assets/prompt_audio.wav \
|
mode=streaming
|
||||||
--reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
|
BLS_INSTANCE_NUM=$bls_instance_num
|
||||||
--target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
|
|
||||||
--model-name cosyvoice2
|
python3 client_grpc.py \
|
||||||
|
--server-addr localhost \
|
||||||
|
--server-port 8001 \
|
||||||
|
--model-name cosyvoice2_dit \
|
||||||
|
--num-tasks $num_task \
|
||||||
|
--mode $mode \
|
||||||
|
--huggingface-dataset yuekai/seed_tts_cosy2 \
|
||||||
|
--log-dir ./log_single_gpu_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
|
||||||
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
echo "Running benchmark client grpc on $N_GPUS GPUs"
|
echo "stage 5: Offline TTS (Cosyvoice2 LLM + Step-Audio2-mini DiT Token2Wav) inference using a single python script"
|
||||||
num_task=1
|
|
||||||
|
|
||||||
mode=streaming
|
|
||||||
BLS_INSTANCE_NUM=4
|
|
||||||
|
|
||||||
for i in $(seq 0 $(($N_GPUS - 1))); do
|
|
||||||
grpc_port=$((18000 + $i))
|
|
||||||
echo "Running client for server on localhost:$grpc_port"
|
|
||||||
python3 client_grpc.py \
|
|
||||||
--server-addr localhost \
|
|
||||||
--server-port $grpc_port \
|
|
||||||
--model-name cosyvoice2_dit \
|
|
||||||
--num-tasks $num_task \
|
|
||||||
--mode $mode \
|
|
||||||
--huggingface-dataset yuekai/seed_tts_cosy2 \
|
|
||||||
--log-dir ./log_debug_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_gpu${i} &
|
|
||||||
done
|
|
||||||
wait
|
|
||||||
fi
|
|
||||||
if [ $stage -le 50 ] && [ $stop_stage -ge 50 ]; then
|
|
||||||
echo "Running benchmark client grpc on $N_GPUS GPUs"
|
|
||||||
num_task=4
|
|
||||||
N_GPUS=1
|
|
||||||
mode=streaming
|
|
||||||
BLS_INSTANCE_NUM=4
|
|
||||||
|
|
||||||
for i in $(seq 0 $(($N_GPUS - 1))); do
|
|
||||||
grpc_port=$((18000 + $i))
|
|
||||||
echo "Running client for server on localhost:$grpc_port"
|
|
||||||
python3 client_grpc.py \
|
|
||||||
--server-addr localhost \
|
|
||||||
--server-port $grpc_port \
|
|
||||||
--model-name cosyvoice2_dit \
|
|
||||||
--num-tasks $num_task \
|
|
||||||
--mode $mode \
|
|
||||||
--huggingface-dataset yuekai/seed_tts_cosy2 \
|
|
||||||
--log-dir ./log_single_card_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM} &
|
|
||||||
done
|
|
||||||
wait
|
|
||||||
fi
|
|
||||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
|
||||||
echo "stage 6: Offline inference benchmark"
|
|
||||||
n_gpus=1
|
|
||||||
datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
|
datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
|
||||||
backend=trtllm-serve # hf, trtllm, vllm
|
backend=trtllm # hf, trtllm, vllm, trtllm-serve
|
||||||
|
|
||||||
batch_sizes=(16 8 4 2 1)
|
batch_sizes=(16)
|
||||||
batch_sizes=(16 8 4 2)
|
|
||||||
token2wav_batch_size=1
|
token2wav_batch_size=1
|
||||||
|
|
||||||
for batch_size in ${batch_sizes[@]}; do
|
for batch_size in ${batch_sizes[@]}; do
|
||||||
for dataset in ${datasets[@]}; do
|
for dataset in ${datasets[@]}; do
|
||||||
output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
|
output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
|
||||||
@@ -225,7 +154,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
|||||||
python3 offline_inference.py \
|
python3 offline_inference.py \
|
||||||
--output-dir $output_dir \
|
--output-dir $output_dir \
|
||||||
--llm-model-name-or-path $huggingface_model_local_dir \
|
--llm-model-name-or-path $huggingface_model_local_dir \
|
||||||
--token2wav-path $model_scope_model_local_dir \
|
--token2wav-path $step_audio_model_dir/token2wav \
|
||||||
--backend $backend \
|
--backend $backend \
|
||||||
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
|
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
|
||||||
--engine-dir $trt_engines_dir \
|
--engine-dir $trt_engines_dir \
|
||||||
@@ -234,34 +163,13 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
|||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
echo "Running Step-Audio2-mini DiT Token2Wav inference using a single python script"
|
||||||
|
export CUDA_VISIBLE_DEVICES=1
|
||||||
CUDA_VISIBLE_DEVICES=2 python3 streaming_inference.py --enable-trt --strategy exponential
|
# Note: Using pre-computed cosyvoice2 tokens
|
||||||
|
python3 streaming_inference.py --enable-trt --strategy equal # equal, exponential
|
||||||
|
# Offline Token2wav inference
|
||||||
|
# python3 token2wav_dit.py --enable-trt
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
|
||||||
CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4
|
|
||||||
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
|
||||||
#! /usr/bin/env bash
|
|
||||||
curl http://localhost:8000/v1/chat/completions \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"model": "trt_engines_bfloat16",
|
|
||||||
"messages":[{"role": "user", "content": "Where is New York?"},
|
|
||||||
{"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}],
|
|
||||||
"max_tokens": 512,
|
|
||||||
"temperature": 0.8,
|
|
||||||
"top_p": 0.95,
|
|
||||||
"top_k": 50,
|
|
||||||
"stop": ["<|eos1|>"],
|
|
||||||
"repetition_penalty": 1.2,
|
|
||||||
"stream": false
|
|
||||||
}'
|
|
||||||
fi
|
|
||||||
@@ -54,7 +54,7 @@ if __name__ == "__main__":
|
|||||||
token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
|
token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
|
||||||
|
|
||||||
flow_pre_lookahead_len = 3
|
flow_pre_lookahead_len = 3
|
||||||
CHUNK_SIZE = 15
|
CHUNK_SIZE = 25
|
||||||
token_frame_rate = 25
|
token_frame_rate = 25
|
||||||
OVERLAP_SIZE = 0
|
OVERLAP_SIZE = 0
|
||||||
|
|
||||||
@@ -67,20 +67,12 @@ if __name__ == "__main__":
|
|||||||
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch
|
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch
|
||||||
|
|
||||||
id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0]
|
id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0]
|
||||||
# if id != "unseen3_text5":
|
|
||||||
# continue
|
|
||||||
# else:
|
|
||||||
# a = torch.load("semantic_token_ids_arr_debug_871e2b90-42a7-4829-957c-b45e6a96fdb2.pt")
|
|
||||||
# generated_speech_tokens = a["semantic_token_ids_arr"]
|
|
||||||
# print(generated_speech_tokens)
|
|
||||||
assert prompt_audio_sample_rate == 16000
|
assert prompt_audio_sample_rate == 16000
|
||||||
|
|
||||||
prompt_text = prompt_text_list[0]
|
prompt_text = prompt_text_list[0]
|
||||||
prompt_speech_tokens = prompt_speech_tokens_list[0]
|
prompt_speech_tokens = prompt_speech_tokens_list[0]
|
||||||
|
|
||||||
|
|
||||||
# generated_ids_iter = fake_generated_id_iter(generated_speech_tokens)
|
|
||||||
|
|
||||||
semantic_token_ids_arr, token_offset = [], 0
|
semantic_token_ids_arr, token_offset = [], 0
|
||||||
flow_prompt_speech_token_len = len(prompt_speech_tokens)
|
flow_prompt_speech_token_len = len(prompt_speech_tokens)
|
||||||
|
|
||||||
@@ -114,14 +106,16 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
audios = output_wavs
|
audios = output_wavs
|
||||||
reconstructed_audio = np.concatenate(audios)
|
reconstructed_audio = np.concatenate(audios)
|
||||||
# Save reconstructed audio
|
|
||||||
sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
|
sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
|
||||||
|
|
||||||
|
|
||||||
print(f"Saved {id}")
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
if _ == 0:
|
if _ == 0:
|
||||||
token2wav_model.speaker_cache = {}
|
token2wav_model.speaker_cache = {}
|
||||||
print(f"Warmup time: {end_time - start_time} seconds")
|
print(f"Warmup time: {end_time - start_time} seconds")
|
||||||
print(f"Total forward count: {total_forward_count}")
|
print("clear speaker cache")
|
||||||
|
elif _ == 1:
|
||||||
|
print(f"Cost time without speaker cache: {end_time - start_time} seconds")
|
||||||
|
else:
|
||||||
|
print(f"Cost time with speaker cache: {end_time - start_time} seconds")
|
||||||
|
print(f"Total flow matching forward calls: {total_forward_count}")
|
||||||
Reference in New Issue
Block a user