diff --git a/runtime/triton_trtllm/README.md b/runtime/triton_trtllm/README.md
index 420990e..3765038 100644
--- a/runtime/triton_trtllm/README.md
+++ b/runtime/triton_trtllm/README.md
@@ -78,7 +78,7 @@ For offline inference mode benchmark, please check the below command:
 # install FlashCosyVoice for token2wav batching
 # git clone https://github.com/yuekaizhang/FlashCosyVoice.git /workspace/FlashCosyVoice -b trt
 # cd /workspace/FlashCosyVoice
-# pip install -e . 
+# pip install -e .
 # cd -
 
 # wget https://huggingface.co/yuekai/cosyvoice2_flow_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O $model_scope_model_local_dir/flow.decoder.estimator.fp32.dynamic_batch.onnx
@@ -116,7 +116,7 @@ The following results were obtained by decoding on a single L20 GPU with 26 prom
 | HF | 1 | 39.26 | 44.31 | 0.2494 |
 | HF | 2 | 30.54 | 35.62 | 0.2064 |
 | HF | 4 | 18.63 | 23.90 | 0.1421 |
-| HF | 8 | 11.22 | 16.45 | 0.0947 | 
+| HF | 8 | 11.22 | 16.45 | 0.0947 |
 | HF | 16 | 8.42 | 13.78 | 0.0821 |
 | TRTLLM | 1 | 12.46 | 17.31 | 0.0987 |
 | TRTLLM | 2 | 7.64 |12.65 | 0.0739 |
diff --git a/runtime/triton_trtllm/offline_inference.py b/runtime/triton_trtllm/offline_inference.py
index 523cd56..853aefe 100644
--- a/runtime/triton_trtllm/offline_inference.py
+++ b/runtime/triton_trtllm/offline_inference.py
@@ -65,6 +65,7 @@ def extract_speech_ids(speech_tokens_str):
             print(f"Unexpected token: {token_str}")
     return speech_ids
 
+
 def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
     """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
     speech_id_str = ""
@@ -167,7 +168,6 @@ def get_args():
     return args
 
 
-
 def data_collator(batch, tokenizer, s3_tokenizer):
     """Simplified data collator for batch_size=1 processing"""
     collator_start_time = time.time()
@@ -202,7 +202,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
                 item["prompt_audio"]["sampling_rate"],
             )
             ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
-            # ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
             print(ref_audio_org.shape)
 
             if ref_sr != target_sample_rate:
@@ -220,7 +219,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
             prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
             prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
         else:
-            # convert to float first
             mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
 
     if len(mels) > 0:
@@ -287,33 +285,23 @@ def main(args):
     os.makedirs(args.output_dir, exist_ok=True)
 
     assert torch.cuda.is_available()
-    # world_size, local_rank, rank = init_distributed()
     local_rank, world_size, rank = 0, 1, 0
     device = torch.device(f"cuda:{local_rank}")
 
-    # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
-    # model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
 
-    # Initialize backend based on argument
     if args.backend == "hf":
-        # Load HuggingFace model
         model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
         model.eval()
         model.to(device)
         runner = None
     elif args.backend == "trtllm":
-        # Validate engine_dir is provided
         if args.engine_dir is None:
             raise ValueError("--engine-dir is required when backend is 'trtllm'")
 
-        # import tensorrt_llm
-        #from tensorrt_llm.runtime import ModelRunnerCpp
-        # Initialize TensorRT-LLM runner
         runtime_rank = tensorrt_llm.mpi_rank()
         model = None
 
-        # Prepare input for runner initialization
         runner_kwargs = dict(
             engine_dir=args.engine_dir,
             rank=runtime_rank,
@@ -328,7 +316,6 @@ def main(args):
 
         runner = ModelRunnerCpp.from_dir(**runner_kwargs)
     elif args.backend == "vllm":
-        # from vllm import LLM, SamplingParams
         model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
         runner = None
     else:
@@ -349,7 +336,6 @@ def main(args):
         trust_remote_code=True,
     )
 
-    # sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
     sampler = None
     dataloader = DataLoader(
         dataset,
@@ -385,7 +371,6 @@ def main(args):
         total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
         total_text_tokenization_time_in_collator += batch["text_tokenization_time"]
         with torch.no_grad():
-            # Generate speech tokens using LLM
             llm_start_time = time.time()
             if args.backend == "hf":
                 input_ids_list = batch["input_ids"]
@@ -393,31 +378,22 @@ def main(args):
                     input_ids = input_ids_list[0].unsqueeze(0)
                     attention_mask = torch.ones_like(input_ids)
                 else:
-                    # Handle batch > 1 if needed
                     max_len = max([len(input_ids) for input_ids in input_ids_list])
-                    # input_ids_list_new = [
-                    #     torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids])
-                    #     for input_ids in input_ids_list
-                    # ]
                     input_ids_list_new = [
                         torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)])
                         for input_ids in input_ids_list
                     ]
                     input_ids = torch.stack(input_ids_list_new)
 
-                    # compute attention mask
                     attention_mask = torch.zeros_like(input_ids)
                     for i in range(len(input_ids_list)):
                         attention_mask[i, :len(input_ids_list[i])] = 1
 
-                # breakpoint()
-
-
                 input_ids = input_ids.to(device)
                 outputs = model.generate(
                     input_ids=input_ids.to(device),
                     attention_mask=attention_mask.to(device),
-                    max_new_tokens=2048,  # Max length for generation
+                    max_new_tokens=2048,
                     do_sample=True,
                     top_p=args.top_p,
                     temperature=args.temperature,
@@ -426,14 +402,11 @@ def main(args):
                 )
                 torch.cuda.synchronize()
             elif args.backend == "trtllm":
-                # Convert input_ids to list of tensors for TensorRT-LLM
                 batch_input_ids = [ids for ids in batch["input_ids"]]
                 input_lengths = [x.size(0) for x in batch_input_ids]
-                # Get end_id from tokenizer
                 end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
                 print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================")
 
-                # random_seed=42, repetition_penalty=1.1,
                 outputs = runner.generate(
                     batch_input_ids=batch_input_ids,
                     max_new_tokens=2048,
@@ -451,7 +424,6 @@ def main(args):
                     return_all_generated_tokens=False
                 )
                 torch.cuda.synchronize()
-                # Extract output_ids from TensorRT-LLM output
                 output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
                 num_output_sents, num_beams, _ = output_ids.size()
                 assert num_beams == 1
@@ -463,18 +435,12 @@ def main(args):
                 for i in range(batch_size * num_return_sequences):
                     batch_idx = i // num_return_sequences
                     seq_idx = i % num_return_sequences
-                    # inputs = output_ids[i][0][:input_lengths[batch_idx]].tolist()
-                    # input_text = tokenizer.decode(inputs)
-                    # print(f'Input [Text {batch_idx}]: \"{input_text}\"')
                     output_begin = input_lengths[batch_idx]
                     output_end = sequence_lengths[i][beam]
-                    # outputs_i = output_ids[i][beam][output_begin:output_end].tolist()
                     outputs_i = output_ids[i][beam][:output_end].tolist()
                     outputs.append(outputs_i)
             elif args.backend == "vllm":
                 input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
-                # prompts = [batch["prompt_text_after_apply_template"][i] for i in range(len(batch["prompt_text_after_apply_template"]))]
-                # print(prompts)
                 sampling_params = SamplingParams(
                     temperature=args.temperature,
                     top_p=args.top_p,
@@ -483,26 +449,21 @@ def main(args):
                     max_tokens=2048,
                 )
                 outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
-                # outputs = model.generate(prompts, sampling_params)
                 print(outputs)
-                # breakpoint()
                 for j, output in enumerate(outputs):
                     outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
             llm_end_time = time.time()
             total_llm_time += (llm_end_time - llm_start_time)
 
-        items_for_token_2wav = []
+        items_for_token2wav = []
 
         for i in range(len(batch["ids"])):
             llm_post_processing_start_time = time.time()
-            # Extract generated tokens (excluding input)
             input_length = len(batch["input_ids"][i])
-            generated_ids = outputs[i][input_length:]  # Remove last token if needed
+            generated_ids = outputs[i][input_length:]
            speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-            # Extract speech IDs from token strings like <|s_23456|>
             speech_ids = extract_speech_ids(speech_tokens_str)
             print(i, speech_ids)
-            # breakpoint()
             if len(speech_ids) == 0:
                 print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                 continue
@@ -517,7 +478,7 @@ def main(args):
             llm_post_processing_end_time = time.time()
             total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
             if current_prompt_audio is not None:
-                items_for_token_2wav.append({
+                items_for_token2wav.append({
                     "speech_ids": speech_ids,
                     "prompt_audio": current_prompt_audio.squeeze(0),
                     "id": batch["ids"][i]
@@ -525,8 +486,8 @@ def main(args):
             else:
                 print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
 
-        for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
-            t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
+        for i in range(0, len(items_for_token2wav), args.token2wav_batch_size):
+            t2w_batch = items_for_token2wav[i:i + args.token2wav_batch_size]
             if not t2w_batch:
                 continue
 
@@ -535,7 +496,6 @@ def main(args):
             t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
             t2w_ids = [item["id"] for item in t2w_batch]
 
-            # Generate audio using CosyVoice2
             token2wav_start_time = time.time()
             generated_wavs = token2wav_model(
                 t2w_generated_speech_tokens_list,
@@ -547,7 +507,6 @@ def main(args):
             total_token2wav_time += (token2wav_end_time - token2wav_start_time)
 
             audio_save_start_time = time.time()
-            # Convert to numpy and save
             for j, audio_hat in enumerate(generated_wavs):
                 generated_wave = audio_hat.squeeze().cpu().numpy()
                 total_audio_samples += len(generated_wave)
@@ -571,7 +530,6 @@ def main(args):
 
     log_file_path = os.path.join(args.output_dir, "log.txt")
     with open(log_file_path, 'w') as f:
-        # Convert Namespace to dict for JSON serialization
         args_dict = vars(args)
         log_data = {
             "args": args_dict,
@@ -602,4 +560,4 @@ if __name__ == "__main__":
         from transformers import AutoModelForCausalLM
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
-    main(args)
\ No newline at end of file
+    main(args)
diff --git a/runtime/triton_trtllm/token2wav.py b/runtime/triton_trtllm/token2wav.py
index 786c582..86e4625 100644
--- a/runtime/triton_trtllm/token2wav.py
+++ b/runtime/triton_trtllm/token2wav.py
@@ -70,6 +70,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
         f.write(engine_bytes)
     logging.info("Succesfully convert onnx to trt...")
 
+
 class TrtContextWrapper:
     def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
         self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
@@ -88,12 +89,13 @@ class TrtContextWrapper:
     def release_estimator(self, context, stream):
         self.trt_context_pool.put([context, stream])
 
+
 class CosyVoice2_Token2Wav(torch.nn.Module):
     def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = False, device_id: int = 0):
         super().__init__()
         self.device_id = device_id
         self.device = f"cuda:{device_id}"
-        
+
         self.flow = CausalMaskedDiffWithXvec()
         self.flow.half()
         self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
@@ -107,22 +109,20 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         option = onnxruntime.SessionOptions()
         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         option.intra_op_num_threads = 1
-        self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
-                                                      providers=["CPUExecutionProvider"])
-        
+        self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])
+
         self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2.onnx").to(self.device).eval()
 
-        gpu="l20"
+        gpu = "l20"
         if enable_trt:
             self.load_trt(f'{model_dir}/flow.decoder.estimator.fp16.dynamic_batch.{gpu}.plan',
-                         f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
-                         1,
-                         True)
+                          f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
+                          1,
+                          True)
             self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
-                             f'{model_dir}/campplus.onnx',
-                             1,
-                             False)
-        
+                              f'{model_dir}/campplus.onnx',
+                              1,
+                              False)
 
     def forward_spk_embedding(self, spk_feat):
         if isinstance(self.spk_model, onnxruntime.InferenceSession):
@@ -173,7 +173,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
     def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, fp16=True):
         assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
         if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
-            trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=2, max_batch_size=16)
+            trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_bs=2, max_batch_size=16)
             convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, fp16)
         del self.flow.decoder.estimator
         import tensorrt as trt
@@ -182,10 +182,11 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
         self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
 
-    def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64):
+    def get_trt_kwargs_dynamic_batch(self, opt_bs=2, max_batch_size=64):
         min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
-        opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)]
-        max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)]
+        opt_shape = [(opt_bs * 2, 80, 500), (opt_bs * 2, 1, 500), (opt_bs * 2, 80, 500), (opt_bs * 2, 80, 500), (opt_bs * 2,), (opt_bs * 2, 80)]
+        max_shape = [(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2,),
+                     (max_batch_size * 2, 80)]
         input_names = ["x", "mask", "mu", "cond", "t", "spks"]
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
 
@@ -203,7 +204,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
             prompt_speech_tokens_list.append(speech_tokens_i)
         return prompt_speech_tokens_list
-    
+
     def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
         spk_emb_for_flow = []
         for audio in prompt_audios_list:
@@ -213,9 +214,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             spk_emb = self.forward_spk_embedding(spk_feat)
             spk_emb_for_flow.append(spk_emb)
 
-        spk_emb_for_flow = torch.tensor(spk_emb_for_flow) 
+        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
         return spk_emb_for_flow
-    
+
     def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
         prompt_mels_for_flow = []
         prompt_mels_lens_for_flow = []
@@ -231,9 +232,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
         prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
         return prompt_mels_for_flow, prompt_mels_lens_for_flow
 
-    
-    def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
+    def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor,
+                     prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
         batch_size = prompt_mels_for_flow.shape[0]
         flow_inputs = []
         flow_inputs_lens = []
@@ -262,14 +263,12 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             generated_wavs.append(wav)
 
         return generated_wavs
 
-
     @torch.inference_mode()
     def forward(
         self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
     ):
         # assert all item in prompt_audios_sample_rate is 16000
         assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
-
         prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
 
@@ -277,10 +276,11 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
         spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
 
-        generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
+        generated_mels, generated_mels_lens = self.forward_flow(
+            prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
 
         generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
-    
+
         return generated_wavs
 
 
@@ -288,13 +288,14 @@ def collate_fn(batch):
     ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
     for i, item in enumerate(batch):
         generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
-        audio = torch.from_numpy(item['prompt_audio']['array']).float() 
+        audio = torch.from_numpy(item['prompt_audio']['array']).float()
         prompt_audios_list.append(audio)
         prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
         ids.append(item['id'])
 
     return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
 
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--enable-trt", action="store_true")
@@ -305,6 +306,7 @@ def get_args():
     parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
     return parser.parse_args()
 
+
 if __name__ == "__main__":
     args = get_args()
     model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
@@ -315,22 +317,19 @@ if __name__ == "__main__":
 
 
     dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
-
     data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
-
-
+
     for epoch in range(args.warmup):
         start_time = time.time()
-        
+
         for batch in data_loader:
             ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
             generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
-
             for id, wav in zip(ids, generated_wavs):
                 torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
-        
+
         end_time = time.time()
         epoch_time = end_time - start_time
 
-        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
\ No newline at end of file
+        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")