mirror of https://github.com/FunAudioLLM/CosyVoice.git
add streaming dit
@@ -43,6 +43,9 @@ import soundfile as sf
 import s3tokenizer
 from functools import partial
 import time
+import requests
+import asyncio
+import httpx

 from token2wav import CosyVoice2_Token2Wav

@@ -53,6 +56,32 @@ except RuntimeError:
     pass


+async def send_request_async(client, url, payload):
+    response = await client.post(url, json=payload, timeout=None)
+    response.raise_for_status()
+    response_json = response.json()
+    return response_json['choices'][0]['message']['content']
+
+
+async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
+    async with httpx.AsyncClient() as client:
+        tasks = []
+        for chat in chats:
+            payload = {
+                "model": model_name,
+                "messages": chat,
+                "max_tokens": 2048,
+                "temperature": temperature,
+                "top_p": top_p,
+                "top_k": top_k,
+                "repetition_penalty": 1.1,
+                "stop": ["<|eos1|>", "<|eos|>"],
+                "stream": False,
+            }
+            tasks.append(send_request_async(client, api_base, payload))
+        return await asyncio.gather(*tasks)
+
+
 def extract_speech_ids(speech_tokens_str):
     """Extract speech IDs from token strings like <|s_23456|>"""
     speech_ids = []
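A minimal usage sketch of the helpers above (not part of the patch); the endpoint URL, model name, sampling values, and chat content are placeholders, and an OpenAI-compatible chat-completions server (e.g. trtllm-serve) is assumed to be running:

    import asyncio

    # Placeholder chats mirroring the structure built in data_collator() further down:
    # the text to synthesize as the user turn, and the prompt audio's token string as an assistant turn.
    chats = [
        [
            {"role": "user", "content": "prompt text plus target text"},
            {"role": "assistant", "content": "<|s_101|><|s_202|>"},
        ],
    ]

    contents = asyncio.run(send_batch_requests_async(
        "http://localhost:8000/v1/chat/completions",  # assumed endpoint (matches the --openai-api-base default)
        "trt_engines_bfloat16",                       # assumed model name (matches the --openai-model-name default)
        chats,
        0.8,   # temperature (placeholder)
        0.95,  # top_p (placeholder)
        50,    # top_k (placeholder)
    ))
    print(contents)  # one generated speech-token string per chat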
@@ -149,7 +178,7 @@ def get_args():
         "--backend",
         type=str,
         default="hf",
-        choices=["hf", "trtllm", "vllm"],
+        choices=["hf", "trtllm", "vllm", "trtllm-serve"],
         help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
     )
     parser.add_argument(
@@ -164,6 +193,18 @@ def get_args():
         default=0.6,
         help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
     )
+    parser.add_argument(
+        "--openai-api-base",
+        type=str,
+        default="http://localhost:8000/v1/chat/completions",
+        help="OpenAI API base URL (for trtllm-serve backend)",
+    )
+    parser.add_argument(
+        "--openai-model-name",
+        type=str,
+        default="trt_engines_bfloat16",
+        help="Model name to use with OpenAI API (for trtllm-serve backend)",
+    )
     args = parser.parse_args()
     return args

@@ -180,6 +221,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
     input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
     prompt_text_after_apply_template_list = []
     mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
+    chat_list = []
     for _, item in enumerate(batch):
         audio_processing_start_time = time.time()
         prompt_text, target_text = (
@@ -237,6 +279,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
             {"role": "user", "content": full_text_list[i]},
             {"role": "assistant", "content": prompt_audio_cosy2_id_str}
         ]
+        chat_list.append(chat)

         assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"

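For reference, a sketch (placeholder strings, not from the commit) of one entry appended to chat_list; the assistant turn carries the prompt audio's CosyVoice2 token string, presumably serving as a prefix for the server to continue with tokens for the target text:

    chat = [
        {"role": "user", "content": "prompt text plus target text"},       # full_text_list[i]
        {"role": "assistant", "content": "<|s_101|><|s_202|><|s_303|>"},   # prompt_audio_cosy2_id_str
    ]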
@@ -265,6 +308,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         "audio_processing_time": total_audio_processing_time,
         "speech_tokenization_time": total_speech_tokenization_time,
         "text_tokenization_time": total_text_tokenization_time,
+        "chat_list": chat_list
     }


@@ -318,6 +362,9 @@ def main(args):
     elif args.backend == "vllm":
         model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
         runner = None
+    elif args.backend == "trtllm-serve":
+        model = None
+        runner = None
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")

@@ -452,6 +499,35 @@ def main(args):
             print(outputs)
             for j, output in enumerate(outputs):
                 outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
+        elif args.backend == "trtllm-serve":
+            if args.batch_size > 1:
+                outputs = asyncio.run(send_batch_requests_async(
+                    args.openai_api_base,
+                    args.openai_model_name,
+                    batch["chat_list"],
+                    args.temperature,
+                    args.top_p,
+                    args.top_k,
+                ))
+            else:
+                outputs = []
+                for i, chat in enumerate(batch["chat_list"]):
+                    payload = {
+                        "model": args.openai_model_name,
+                        "messages": chat,
+                        "max_tokens": 2048,
+                        "temperature": args.temperature,
+                        "top_p": args.top_p,
+                        "top_k": args.top_k,
+                        "repetition_penalty": 1.1,
+                        "stop": ["<|eos1|>", "<|eos|>"],
+                        "stream": False,
+                    }
+                    response = requests.post(args.openai_api_base, json=payload)
+                    response.raise_for_status()
+                    response_json = response.json()
+                    generated_content = response_json['choices'][0]['message']['content']
+                    outputs.append(generated_content)

         llm_end_time = time.time()
         total_llm_time += (llm_end_time - llm_start_time)
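Both request paths above index response_json['choices'][0]['message']['content']; a sketch of the assumed OpenAI-compatible response shape (field values are placeholders):

    response_json = {
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "<|s_12|><|s_345|>"},
                "finish_reason": "stop",
            }
        ]
    }
    generated_content = response_json['choices'][0]['message']['content']
    assert generated_content == "<|s_12|><|s_345|>"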
@@ -459,10 +535,21 @@ def main(args):
         items_for_token_2wav = []
         for i in range(len(batch["ids"])):
             llm_post_processing_start_time = time.time()
-            input_length = len(batch["input_ids"][i])
-            generated_ids = outputs[i][input_length:]
-            speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-            speech_ids = extract_speech_ids(speech_tokens_str)
+            if args.backend == "trtllm-serve":
+                speech_tokens_str = outputs[i].strip().split('><')
+                if len(speech_tokens_str) > 1:
+                    speech_tokens_str = [
+                        t if t.startswith('<') else '<' + t for t in speech_tokens_str
+                    ]
+                    speech_tokens_str = [
+                        t if t.endswith('>') else t + '>' for t in speech_tokens_str
+                    ]
+                speech_ids = extract_speech_ids(speech_tokens_str)
+            else:
+                input_length = len(batch["input_ids"][i])
+                generated_ids = outputs[i][input_length:]
+                speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+                speech_ids = extract_speech_ids(speech_tokens_str)
             print(i, speech_ids)
             if len(speech_ids) == 0:
                 print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
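A worked sketch (sample string, not from the commit) of the trtllm-serve branch above: the server returns the speech tokens as one string, which is split on '><' and the stripped brackets are restored before extract_speech_ids() parses the integer IDs:

    content = "<|s_12|><|s_345|><|s_6789|>"
    speech_tokens_str = content.strip().split('><')   # ['<|s_12|', '|s_345|', '|s_6789|>']
    if len(speech_tokens_str) > 1:
        speech_tokens_str = [t if t.startswith('<') else '<' + t for t in speech_tokens_str]
        speech_tokens_str = [t if t.endswith('>') else t + '>' for t in speech_tokens_str]
    print(speech_tokens_str)   # ['<|s_12|>', '<|s_345|>', '<|s_6789|>']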
@@ -558,6 +645,8 @@ if __name__ == "__main__":
         from tensorrt_llm.runtime import ModelRunnerCpp
     elif args.backend == "hf":
         from transformers import AutoModelForCausalLM
+    elif args.backend == "trtllm-serve":
+        pass
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
     main(args)