add streaming dit

yuekaiz
2025-09-24 15:18:01 +08:00
parent 444b7ff5df
commit 482464ea27
10 changed files with 850 additions and 269 deletions


@@ -43,6 +43,9 @@ import soundfile as sf
import s3tokenizer
from functools import partial
import time
import requests
import asyncio
import httpx
from token2wav import CosyVoice2_Token2Wav
@@ -53,6 +56,32 @@ except RuntimeError:
    pass


async def send_request_async(client, url, payload):
    response = await client.post(url, json=payload, timeout=None)
    response.raise_for_status()
    response_json = response.json()
    return response_json['choices'][0]['message']['content']


async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
    async with httpx.AsyncClient() as client:
        tasks = []
        for chat in chats:
            payload = {
                "model": model_name,
                "messages": chat,
                "max_tokens": 2048,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "repetition_penalty": 1.1,
                "stop": ["<|eos1|>", "<|eos|>"],
                "stream": False,
            }
            tasks.append(send_request_async(client, api_base, payload))
        return await asyncio.gather(*tasks)


def extract_speech_ids(speech_tokens_str):
    """Extract speech IDs from token strings like <|s_23456|>"""
    speech_ids = []
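
For reference, a minimal sketch of how the new batch helper can be driven (the sampling values below are placeholders; the real call site in main() passes args.openai_api_base, args.openai_model_name, the collator's chat_list, and args.temperature/top_p/top_k):

# Hypothetical driver for send_batch_requests_async; values are illustrative only.
example_chats = [
    [
        {"role": "user", "content": "prompt transcript plus target text"},
        {"role": "assistant", "content": "<|s_101|><|s_2056|>"},
    ]
]
contents = asyncio.run(
    send_batch_requests_async(
        "http://localhost:8000/v1/chat/completions",  # OpenAI-compatible endpoint (default of --openai-api-base)
        "trt_engines_bfloat16",                       # served model name (default of --openai-model-name)
        example_chats,
        0.8,    # temperature (placeholder)
        0.95,   # top_p (placeholder)
        50,     # top_k (placeholder)
    )
)
# contents is a list of generated token strings, one per chat, in input order.
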
@@ -149,7 +178,7 @@ def get_args():
"--backend",
type=str,
default="hf",
choices=["hf", "trtllm", "vllm"],
choices=["hf", "trtllm", "vllm", "trtllm-serve"],
help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
)
parser.add_argument(
@@ -164,6 +193,18 @@ def get_args():
        default=0.6,
        help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
    )
    parser.add_argument(
        "--openai-api-base",
        type=str,
        default="http://localhost:8000/v1/chat/completions",
        help="OpenAI API base URL (for trtllm-serve backend)",
    )
    parser.add_argument(
        "--openai-model-name",
        type=str,
        default="trt_engines_bfloat16",
        help="Model name to use with OpenAI API (for trtllm-serve backend)",
    )
    args = parser.parse_args()
    return args
@@ -180,6 +221,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
    input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
    prompt_text_after_apply_template_list = []
    mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
    chat_list = []
    for _, item in enumerate(batch):
        audio_processing_start_time = time.time()
        prompt_text, target_text = (
@@ -237,6 +279,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
{"role": "user", "content": full_text_list[i]},
{"role": "assistant", "content": prompt_audio_cosy2_id_str}
]
chat_list.append(chat)
assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
@@ -265,6 +308,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
"audio_processing_time": total_audio_processing_time,
"speech_tokenization_time": total_speech_tokenization_time,
"text_tokenization_time": total_text_tokenization_time,
"chat_list": chat_list
}
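
To make the new field concrete, each chat_list entry is a two-turn chat whose assistant turn carries the prompt audio's CosyVoice2 speech-token string, matching the chat built in the collator above (the literal values here are illustrative; the user content is presumably the prompt transcript followed by the target text):

# Hypothetical shape of one chat_list entry produced by data_collator().
example_chat = [
    {"role": "user", "content": "prompt transcript followed by the target text"},
    {"role": "assistant", "content": "<|s_101|><|s_2056|><|s_987|>"},
]
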
@@ -318,6 +362,9 @@ def main(args):
    elif args.backend == "vllm":
        model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
        runner = None
    elif args.backend == "trtllm-serve":
        model = None
        runner = None
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")
@@ -452,6 +499,35 @@ def main(args):
            print(outputs)
            for j, output in enumerate(outputs):
                outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
        elif args.backend == "trtllm-serve":
            if args.batch_size > 1:
                outputs = asyncio.run(send_batch_requests_async(
                    args.openai_api_base,
                    args.openai_model_name,
                    batch["chat_list"],
                    args.temperature,
                    args.top_p,
                    args.top_k,
                ))
            else:
                outputs = []
                for i, chat in enumerate(batch["chat_list"]):
                    payload = {
                        "model": args.openai_model_name,
                        "messages": chat,
                        "max_tokens": 2048,
                        "temperature": args.temperature,
                        "top_p": args.top_p,
                        "top_k": args.top_k,
                        "repetition_penalty": 1.1,
                        "stop": ["<|eos1|>", "<|eos|>"],
                        "stream": False,
                    }
                    response = requests.post(args.openai_api_base, json=payload)
                    response.raise_for_status()
                    response_json = response.json()
                    generated_content = response_json['choices'][0]['message']['content']
                    outputs.append(generated_content)

        llm_end_time = time.time()
        total_llm_time += (llm_end_time - llm_start_time)
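
Both the async and the sequential paths parse the same OpenAI-style chat completion response; only choices[0].message.content is read. A minimal sketch of the shape they expect (fields that are not used are omitted, and the content string is illustrative):

# Abbreviated response shape assumed by send_request_async() and the requests.post path above.
example_response = {
    "choices": [
        {"message": {"role": "assistant", "content": "<|s_123|><|s_456|>"}}
    ]
}
generated_content = example_response["choices"][0]["message"]["content"]
assert generated_content == "<|s_123|><|s_456|>"
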
@@ -459,10 +535,21 @@ def main(args):
        items_for_token_2wav = []
        for i in range(len(batch["ids"])):
            llm_post_processing_start_time = time.time()
            input_length = len(batch["input_ids"][i])
            generated_ids = outputs[i][input_length:]
            speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            speech_ids = extract_speech_ids(speech_tokens_str)
            if args.backend == "trtllm-serve":
                speech_tokens_str = outputs[i].strip().split('><')
                if len(speech_tokens_str) > 1:
                    speech_tokens_str = [
                        t if t.startswith('<') else '<' + t for t in speech_tokens_str
                    ]
                    speech_tokens_str = [
                        t if t.endswith('>') else t + '>' for t in speech_tokens_str
                    ]
                speech_ids = extract_speech_ids(speech_tokens_str)
            else:
                input_length = len(batch["input_ids"][i])
                generated_ids = outputs[i][input_length:]
                speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                speech_ids = extract_speech_ids(speech_tokens_str)
            print(i, speech_ids)
            if len(speech_ids) == 0:
                print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
@@ -558,6 +645,8 @@ if __name__ == "__main__":
        from tensorrt_llm.runtime import ModelRunnerCpp
    elif args.backend == "hf":
        from transformers import AutoModelForCausalLM
    elif args.backend == "trtllm-serve":
        pass
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")
    main(args)