diff --git a/examples/grpo/cosyvoice2/infer_dataset.py b/examples/grpo/cosyvoice2/infer_dataset.py
index 4dcbc96..f0d22d7 100644
--- a/examples/grpo/cosyvoice2/infer_dataset.py
+++ b/examples/grpo/cosyvoice2/infer_dataset.py
@@ -53,7 +53,7 @@
 except RuntimeError:
     pass
 
-TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"
+TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"  # noqa: E501
 
 
 def audio_decode_cosyvoice2(
diff --git a/examples/grpo/cosyvoice2/pretrained_to_huggingface.py b/examples/grpo/cosyvoice2/pretrained_to_huggingface.py
index 161a11f..7aaa10b 100644
--- a/examples/grpo/cosyvoice2/pretrained_to_huggingface.py
+++ b/examples/grpo/cosyvoice2/pretrained_to_huggingface.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python3
-
 # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
diff --git a/examples/grpo/cosyvoice2/scripts/offline-decode-files.py b/examples/grpo/cosyvoice2/scripts/offline-decode-files.py
index 847d434..90c4665 100644
--- a/examples/grpo/cosyvoice2/scripts/offline-decode-files.py
+++ b/examples/grpo/cosyvoice2/scripts/offline-decode-files.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python3
-#
 # Copyright (c) 2023 by manyeyes
 # Copyright (c) 2023 Xiaomi Corporation
 
@@ -195,7 +193,7 @@ def write_error_stats(
             hyp = list("".join(hyp))
             results[i] = (cut_id, ref, hyp)
 
-    for cut_id, ref, hyp in results:
+    for _cut_id, ref, hyp in results:
         ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
         for ref_word, hyp_word in ali:
             if ref_word == ERR:
diff --git a/examples/grpo/cosyvoice2/token2wav_asr_server.py b/examples/grpo/cosyvoice2/token2wav_asr_server.py
index 8a6cb6e..9f9f80b 100644
--- a/examples/grpo/cosyvoice2/token2wav_asr_server.py
+++ b/examples/grpo/cosyvoice2/token2wav_asr_server.py
@@ -295,7 +295,7 @@ def main():
         metrics_port=8002,
     )
 
-    device_ids = [i for i in range(args.number_of_devices)]
+    device_ids = list(range(args.number_of_devices))
     device_ids = device_ids * args.number_of_instances_per_device
 
     with Triton(config=triton_config) as triton:
diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py
index 840390d..b344849 100644
--- a/runtime/triton_trtllm/client_grpc.py
+++ b/runtime/triton_trtllm/client_grpc.py
@@ -122,7 +122,10 @@ def write_triton_stats(stats, summary_file):
             total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
             total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
             summary_f.write(
-                f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"
+                f"queue time {total_queue_time_s:<5.2f} s, "
+                f"compute infer time {total_infer_time_s:<5.2f} s, "
+                f"compute input time {total_input_time_s:<5.2f} s, "
+                f"compute output time {total_output_time_s:<5.2f} s \n"
             )
             model_batch_stats = model_state["batch_stats"]
             for batch in model_batch_stats:
@@ -136,7 +139,12 @@ def write_triton_stats(stats, summary_file):
                 compute_input_time_ms = int(compute_input["ns"]) / 1e6
                 compute_output_time_ms = int(compute_output["ns"]) / 1e6
                 summary_f.write(
-                    f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
+                    f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, "
+                    f"total_infer_time {compute_infer_time_ms:<9.2f} ms, "
+                    f"avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}="
+                    f"{compute_infer_time_ms / batch_count:.2f} ms, "
+                    f"avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}="
+                    f"{compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
                 )
                 summary_f.write(
                     f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "
diff --git a/runtime/triton_trtllm/client_http.py b/runtime/triton_trtllm/client_http.py
index 4d73e0b..2c30a7a 100644
--- a/runtime/triton_trtllm/client_http.py
+++ b/runtime/triton_trtllm/client_http.py
@@ -25,7 +25,6 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 import requests
 import soundfile as sf
-import json
 import numpy as np
 import argparse
 
diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py
index 97659ad..a2bfb30 100644
--- a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py
+++ b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py
@@ -25,12 +25,9 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import json
-import math
 import os
-import re
 import threading
 import time
-from typing import Dict, List, Tuple, Optional, Union
 
 import numpy as np
 import torch
diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py
index 8e2b28b..2e2533c 100644
--- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py
+++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py
@@ -178,7 +178,6 @@ class TritonPythonModel:
                 yield final_id
                 buffer = buffer[match.end():]
 
-
     def forward_audio_tokenizer(self, wav, wav_len):
         """Forward pass through the audio tokenizer component.
 
@@ -263,7 +262,7 @@ class TritonPythonModel:
             ],
             inputs=inputs_tensor,
             request_id=request_id,
-            parameters={"priority": index+1},
+            parameters={"priority": index + 1},
         )
 
         inference_response = await inference_request.async_exec()
diff --git a/runtime/triton_trtllm/model_repo/token2wav/1/model.py b/runtime/triton_trtllm/model_repo/token2wav/1/model.py
index 10bc272..1e38052 100644
--- a/runtime/triton_trtllm/model_repo/token2wav/1/model.py
+++ b/runtime/triton_trtllm/model_repo/token2wav/1/model.py
@@ -28,7 +28,6 @@
 import json
 import os
 import logging
-from typing import List, Dict
 
 import torch
 from torch.utils.dlpack import to_dlpack
diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py
index 1f90644..f9e461e 100644
--- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py
+++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py
@@ -48,9 +48,11 @@ import hashlib
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
+
 ORIGINAL_VOCAB_SIZE = 151663
 
 torch.set_num_threads(1)
 
+
 def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
     """
     Generates a unique ID for a torch.Tensor.
@@ -65,6 +67,7 @@ def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
 
     return hasher.hexdigest()
 
+
 class TritonPythonModel:
     """Triton Python model for vocoder.
 
@@ -114,7 +117,6 @@ class TritonPythonModel:
 
             request_id = request.request_id()
 
-
             wav_array = pb_utils.get_input_tensor_by_name(
                 request, "reference_wav").as_numpy()
             wav_len = pb_utils.get_input_tensor_by_name(
@@ -125,7 +127,10 @@ class TritonPythonModel:
 
             spk_id = get_spk_id_from_prompt_audio(wav)
 
-            audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
+            audio_hat = self.token2wav_model.forward_streaming(
+                target_speech_tokens, finalize, request_id=request_id,
+                speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000
+            )
 
             outputs = []
 
diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
index d413003..6bce5cc 100644
--- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
+++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
@@ -35,7 +35,7 @@ import numpy as np
 from hyperpyyaml import load_hyperpyyaml
 
 
-def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor):
+def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor):
     """perform fade_in_out in tensor style
     """
     mel_overlap_len = int(window.shape[0] / 2)
@@ -45,6 +45,7 @@ def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torc
         fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
     return fade_in_mel
 
+
 def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
     import tensorrt as trt
     logging.info("Converting onnx to trt...")
@@ -90,6 +91,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
         f.write(engine_bytes)
     logging.info("Succesfully convert onnx to trt...")
 
+
 class TrtContextWrapper:
     def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
         self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
@@ -108,6 +110,7 @@ class TrtContextWrapper:
     def release_estimator(self, context, stream):
         self.trt_context_pool.put([context, stream])
 
+
 class CosyVoice2_Token2Wav(torch.nn.Module):
     def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
         super().__init__()
@@ -131,27 +134,33 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         option = onnxruntime.SessionOptions()
         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         option.intra_op_num_threads = 1
-        self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
-                                                      providers=["CPUExecutionProvider"])
+        self.spk_model = onnxruntime.InferenceSession(
+            f"{model_dir}/campplus.onnx", sess_options=option,
+            providers=["CPUExecutionProvider"])
         self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()
-        gpu="l20"
+        gpu = "l20"
 
         if enable_trt:
             if streaming:
-                self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
-                              f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
-                              1,
-                              self.dtype, streaming)
+                self.load_trt(
+                    f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
+                    f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
+                    1,
+                    self.dtype, streaming
+                )
             else:
-                self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
-                              f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
-                              1,
-                              self.dtype)
-            self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
-                              f'{model_dir}/campplus.onnx',
-                              1,
-                              False)
-
+                self.load_trt(
+                    f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
+                    f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
+                    1,
+                    self.dtype
+                )
+            self.load_spk_trt(
+                f'{model_dir}/campplus.{gpu}.fp32.trt',
+                f'{model_dir}/campplus.onnx',
+                1,
+                False
+            )
 
         self.streaming_flow_cache = {}
         self.speaker_cache = {}
@@ -215,7 +224,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         opt_batch_size = 2
         max_batch_size = 16
         if streaming:
-            opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts
+            opt_batch_size, max_batch_size = 1, 1  # only support batch size 1 for streaming tts
         trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
         convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
         del self.flow.decoder.estimator
@@ -228,13 +237,27 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
     def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
         if streaming:
             min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
-            opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80), (16, opt_batch_size*2, 1024, 2), (16, opt_batch_size*2, 8, 100, 128)]
-            max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80), (16, max_batch_size*2, 1024, 2), (16, max_batch_size*2, 8, 1000, 128)]
+            opt_shape = [
+                (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500),
+                (opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2),
+                (16, opt_batch_size * 2, 8, 100, 128)
+            ]
+            max_shape = [
+                (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000),
+                (max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2),
+                (16, max_batch_size * 2, 8, 1000, 128)
+            ]
             input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
         else:
             min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
-            opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)]
-            max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)]
+            opt_shape = [
+                (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500),
+                (opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80)
+            ]
+            max_shape = [
+                (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000),
+                (max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80)
+            ]
             input_names = ["x", "mask", "mu", "cond", "t", "spks"]
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
 
@@ -279,11 +302,17 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             mel_len = mel.shape[0]
             prompt_mels_for_flow.append(mel)
             prompt_mels_lens_for_flow.append(mel_len)
-        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
+        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
+            prompt_mels_for_flow, batch_first=True, padding_value=0
+        )  # [B, T', num_mels=80]
         prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
         return prompt_mels_for_flow, prompt_mels_lens_for_flow
 
-    def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
+    def forward_flow(self, prompt_speech_tokens_list: list[list[int]],
+                     generated_speech_tokens_list: list[list[int]],
+                     prompt_mels_for_flow: torch.Tensor,
+                     prompt_mels_lens_for_flow: torch.Tensor,
+                     spk_emb_for_flow: torch.Tensor):
         batch_size = prompt_mels_for_flow.shape[0]
         flow_inputs = []
         flow_inputs_lens = []
@@ -311,7 +340,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             generated_wavs.append(wav)
         return generated_wavs
 
-
     @torch.inference_mode()
     def forward(
         self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
@@ -320,7 +348,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
         prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
 
-        generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
+        generated_mels, generated_mels_lens = self.forward_flow(
+            prompt_speech_tokens_list, generated_speech_tokens_list,
+            prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
+        )
 
         generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
         return generated_wavs
@@ -337,7 +368,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
         return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
 
-
     def get_prompt_audio_cache_for_streaming_tts(
         self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
     ):
@@ -356,7 +386,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
         return new_cache
 
-
     @torch.inference_mode()
     def forward_streaming(
         self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
@@ -379,9 +408,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         if request_id not in self.streaming_flow_cache:
             self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
             self.hift_cache_dict[request_id] = dict(
-                mel = torch.zeros(1, 80, 0, device='cuda'),
-                source = torch.zeros(1, 1, 0, device='cuda'),
-                speech = torch.zeros(1, 0, device='cuda'),
+                mel=torch.zeros(1, 80, 0, device='cuda'),
+                source=torch.zeros(1, 1, 0, device='cuda'),
+                speech=torch.zeros(1, 0, device='cuda'),
             )
 
         current_request_cache = self.streaming_flow_cache[request_id]
@@ -389,7 +418,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
 
         generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
-
         chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
             token=generated_speech_tokens,
             spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
@@ -400,15 +428,12 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
         self.streaming_flow_cache[request_id] = new_streaming_flow_cache
 
-
         if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
             self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
                 self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
                 self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
             ], dim=4)
 
-
-
         hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
         hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
         hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()
@@ -422,9 +447,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         # update vocoder cache
         self.hift_cache_dict[request_id] = dict(
-            mel = mel[..., -self.mel_cache_len:].clone().detach(),
-            source = source[:, :, -self.source_cache_len:].clone().detach(),
-            speech = speech[:, -self.source_cache_len:].clone().detach(),
+            mel=mel[..., -self.mel_cache_len:].clone().detach(),
+            source=source[:, :, -self.source_cache_len:].clone().detach(),
+            speech=speech[:, -self.source_cache_len:].clone().detach(),
         )
 
         if not last_chunk:
             speech = speech[:, :-self.source_cache_len]
@@ -436,6 +461,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
         return speech
 
+
 def collate_fn(batch):
     ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
     for i, item in enumerate(batch):
@@ -447,6 +473,7 @@ def collate_fn(batch):
 
     return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
 
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--enable-trt", action="store_true")
@@ -457,6 +484,7 @@ def get_args():
     parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
     return parser.parse_args()
 
+
 if __name__ == "__main__":
     args = get_args()
     model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
@@ -466,22 +494,17 @@ if __name__ == "__main__":
 
     dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
 
-
     data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
 
-
     for epoch in range(args.warmup):
         start_time = time.time()
-
        for batch in data_loader:
             ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
 
             generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
-
             for id, wav in zip(ids, generated_wavs):
                 torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
 
-
         end_time = time.time()
         epoch_time = end_time - start_time
-        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
\ No newline at end of file
+        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
diff --git a/runtime/triton_trtllm/offline_inference.py b/runtime/triton_trtllm/offline_inference.py
index 77f2915..e3eac2f 100644
--- a/runtime/triton_trtllm/offline_inference.py
+++ b/runtime/triton_trtllm/offline_inference.py
@@ -28,7 +28,6 @@ import argparse
 import json
 import os
 import sys
-from pathlib import Path
 
 import torch
 import torch.distributed as dist
diff --git a/runtime/triton_trtllm/scripts/test_llm.py b/runtime/triton_trtllm/scripts/test_llm.py
index d52d724..90e6710 100644
--- a/runtime/triton_trtllm/scripts/test_llm.py
+++ b/runtime/triton_trtllm/scripts/test_llm.py
@@ -15,11 +15,6 @@
 # limitations under the License.
 
 import argparse
-import ast
-import csv
-import os
-from pathlib import Path
-from typing import List, Optional
 
 import numpy as np
 import torch
diff --git a/runtime/triton_trtllm/streaming_inference.py b/runtime/triton_trtllm/streaming_inference.py
index e9c2ebb..9c4a2fb 100644
--- a/runtime/triton_trtllm/streaming_inference.py
+++ b/runtime/triton_trtllm/streaming_inference.py
@@ -9,6 +9,7 @@ import time
 from token2wav_dit import CosyVoice2_Token2Wav
 import soundfile as sf
 
+
 def collate_fn(batch):
     ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
     prompt_speech_tokens_list, prompt_text_list = [], []
@@ -23,6 +24,7 @@ def collate_fn(batch):
 
     return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list
 
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--enable-trt", action="store_true")
@@ -79,7 +81,11 @@ if __name__ == "__main__":
             this_chunk_size = token_frame_rate * (2 ** chunk_index)
 
             if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len:
-                wavs = token2wav_model.forward_streaming(buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
+                wavs = token2wav_model.forward_streaming(
+                    buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len],
+                    False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio,
+                    prompt_audio_sample_rate=prompt_audio_sample_rate
+                )
                 buffer = buffer[this_chunk_size - OVERLAP_SIZE:]
                 output_wavs.append(wavs)
 
@@ -87,7 +93,10 @@
                 chunk_index += 1
 
             else:
-                wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
+                wavs = token2wav_model.forward_streaming(
+                    buffer, True, request_id=id, speaker_id=f"{id}",
+                    prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate
+                )
                 output_wavs.append(wavs)
                 total_forward_count += 1
                 # chunk_index += 1
@@ -96,7 +105,6 @@ if __name__ == "__main__":
 
         for i, wav in enumerate(output_wavs):
             output_wavs[i] = wav.cpu().numpy().squeeze()
 
-
         audios = output_wavs
         reconstructed_audio = np.concatenate(audios)
         sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
@@ -111,4 +119,4 @@ if __name__ == "__main__":
         print(f"Cost time without speaker cache: {end_time - start_time} seconds")
     else:
         print(f"Cost time with speaker cache: {end_time - start_time} seconds")
-    print(f"Total flow matching forward calls: {total_forward_count}")
\ No newline at end of file
+    print(f"Total flow matching forward calls: {total_forward_count}")