add prompt audio cache

This commit is contained in:
yuekaiz
2025-09-05 13:54:39 +08:00
parent 86e7c2d731
commit 6971536358
7 changed files with 112 additions and 53 deletions

View File

@@ -77,16 +77,19 @@ The following results were obtained by decoding on a single L20 GPU with 26 prom
**Streaming TTS (First Chunk Latency)**
| Mode | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|---|---|---|---|---|
| Streaming, Decoupled=True | 1 | 220.43 | 218.07 | 0.1237 |
| Streaming, Decoupled=True | 2 | 476.97 | 369.25 | 0.1022 |
| Streaming, Decoupled=True | 4 | 1107.34 | 1243.75| 0.0922 |
| Streaming, use_spk2info_cache=False | 1 | 220.43 | 218.07 | 0.1237 |
| Streaming, use_spk2info_cache=False | 2 | 476.97 | 369.25 | 0.1022 |
| Streaming, use_spk2info_cache=False | 4 | 1107.34 | 1243.75| 0.0922 |
| Streaming, use_spk2info_cache=True | 1 | 189.88 | 184.81 | 0.1155 |
| Streaming, use_spk2info_cache=True | 2 | 323.04 | 316.83 | 0.0905 |
| Streaming, use_spk2info_cache=True | 4 | 977.68 | 903.68| 0.0733 |
**Offline TTS (Full Sentence Latency)**
| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|---|---|---|---|---|---|
| Offline, Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
| Offline, Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
| Offline, Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
### OpenAI-Compatible Server

View File

@@ -257,7 +257,13 @@ def get_args():
default=0.1,
help="Chunk overlap duration for streaming reconstruction (in seconds)."
)
# --- End Added arguments ---
parser.add_argument(
"--use-spk2info-cache",
type=bool,
default=False,
help="Use spk2info cache for reference audio.",
)
return parser.parse_args()
@@ -283,7 +289,8 @@ def prepare_request_input_output(
reference_text,
target_text,
sample_rate=16000,
padding_duration: int = None # Optional padding for offline mode
padding_duration: int = None, # Optional padding for offline mode
use_spk2info_cache: bool = False
):
"""Prepares inputs for Triton inference (offline or streaming)."""
assert len(waveform.shape) == 1, "waveform should be 1D"
@@ -330,7 +337,8 @@ def prepare_request_input_output(
inputs[3].set_data_from_numpy(input_data_numpy)
outputs = [protocol_client.InferRequestedOutput("waveform")]
if use_spk2info_cache:
inputs = inputs[-1:]
return inputs, outputs
@@ -453,6 +461,7 @@ async def send_streaming(
save_sample_rate: int = 16000,
chunk_overlap_duration: float = 0.1,
padding_duration: int = None,
use_spk2info_cache: bool = False,
):
total_duration = 0.0
latency_data = []
@@ -478,7 +487,8 @@ async def send_streaming(
reference_text,
target_text,
sample_rate,
padding_duration=padding_duration
padding_duration=padding_duration,
use_spk2info_cache=use_spk2info_cache
)
request_id = str(uuid.uuid4())
user_data = UserData()
@@ -534,6 +544,7 @@ async def send(
padding_duration: int = None,
audio_save_dir: str = "./",
save_sample_rate: int = 16000,
use_spk2info_cache: bool = False,
):
total_duration = 0.0
latency_data = []
@@ -552,7 +563,8 @@ async def send(
reference_text,
target_text,
sample_rate,
padding_duration=padding_duration
padding_duration=padding_duration,
use_spk2info_cache=use_spk2info_cache
)
sequence_id = 100000000 + i + task_id * 10
start = time.time()
@@ -691,6 +703,7 @@ async def main():
audio_save_dir=args.log_dir,
padding_duration=1,
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
use_spk2info_cache=args.use_spk2info_cache,
)
)
elif args.mode == "streaming":
@@ -706,6 +719,7 @@ async def main():
padding_duration=10,
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
chunk_overlap_duration=args.chunk_overlap_duration,
use_spk2info_cache=args.use_spk2info_cache,
)
)
# --- End Task Creation ---

View File

@@ -43,6 +43,7 @@ import torchaudio
from matcha.utils.audio import mel_spectrogram
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
@@ -81,6 +82,12 @@ class TritonPythonModel:
self.flow_pre_lookahead_len = 3
self.token_hop_len = 15
spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
if not os.path.exists(spk_info_path):
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
self.default_spk_info = spk_info["001"]
def forward_llm(self, input_ids):
"""
Prepares the response from the language model based on the provided
@@ -220,11 +227,11 @@ class TritonPythonModel:
def forward_token2wav(
self,
prompt_speech_tokens: torch.Tensor,
prompt_speech_feat: torch.Tensor,
prompt_spk_embedding: torch.Tensor,
target_speech_tokens: torch.Tensor,
request_id: str,
prompt_speech_tokens: torch.Tensor = None,
prompt_speech_feat: torch.Tensor = None,
prompt_spk_embedding: torch.Tensor = None,
token_offset: int = None,
finalize: bool = None) -> torch.Tensor:
"""Forward pass through the vocoder component.
@@ -238,12 +245,9 @@ class TritonPythonModel:
Returns:
Generated waveform tensor
"""
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
inputs_tensor = [prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
inputs_tensor = [target_speech_tokens_tensor]
if token_offset is not None:
assert finalize is not None
@@ -252,6 +256,13 @@ class TritonPythonModel:
inputs_tensor.append(token_offset_tensor)
inputs_tensor.append(finalize_tensor)
if prompt_spk_embedding is not None:
assert prompt_speech_feat is not None
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav',
@@ -318,25 +329,30 @@ class TritonPythonModel:
request_id = request.request_id()
# Extract input tensors
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
# Process reference audio through audio tokenizer
if wav is not None:
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
wav_tensor = wav.as_numpy()
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
speech_feat = self._extract_speech_feat(prompt_speech_resample)
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
wav_tensor = wav.as_numpy()
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
speech_feat = self._extract_speech_feat(prompt_speech_resample)
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
flow_prompt_speech_token_len = prompt_speech_tokens.shape[-1]
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
else:
# using pre-cached reference text
reference_text = self.default_spk_info["prompt_text"]
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
prompt_speech_feat = None
prompt_spk_embedding = None
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode('utf-8')
@@ -350,7 +366,6 @@ class TritonPythonModel:
# Generate semantic tokens with LLM
generated_ids_iter = self.forward_llm(input_ids)
prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
if self.decoupled:
response_sender = request.get_response_sender()
@@ -380,8 +395,9 @@ class TritonPythonModel:
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(
prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding,
this_tts_speech_token, request_id, token_offset, False)
this_tts_speech_token, request_id, prompt_speech_tokens,
prompt_speech_feat, prompt_spk_embedding, token_offset, False
)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
@@ -414,7 +430,7 @@ class TritonPythonModel:
time.sleep(0.02)
this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, True)
sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
@@ -428,7 +444,7 @@ class TritonPythonModel:
if generated_ids is None or len(generated_ids) == 0:
raise pb_utils.TritonModelException("Generated IDs is None or empty")
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids, request_id)
audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
# Prepare response
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))

View File

@@ -37,16 +37,19 @@ input [
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
optional: true
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
optional: true
},
{
name: "reference_text"
data_type: TYPE_STRING
dims: [1]
optional: true
},
{
name: "target_text"

View File

@@ -187,6 +187,12 @@ class TritonPythonModel:
model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
)
spk_info_path = os.path.join(model_dir, "spk2info.pt")
if not os.path.exists(spk_info_path):
raise ValueError(f"spk2info.pt not found in {model_dir}")
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
self.default_spk_info = spk_info["001"]
logger.info("Token2Wav initialized successfully")
def execute(self, requests):
@@ -202,17 +208,23 @@ class TritonPythonModel:
# Process each request in batch
for request in requests:
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy()
prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
if prompt_speech_tokens_tensor is not None:
prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
else:
prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)
# shift the speech tokens according to the original vocab size
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
# We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.

View File

@@ -35,16 +35,19 @@ input [
name: "prompt_speech_tokens"
data_type: TYPE_INT32
dims: [-1]
optional: true
},
{
name: "prompt_speech_feat"
data_type: TYPE_FP16
dims: [-1, 80]
optional: true
},
{
name: "prompt_spk_embedding"
data_type: TYPE_FP16
dims: [-1]
optional: true
},
{
name: "token_offset"

View File

@@ -15,6 +15,8 @@ trt_engines_dir=./trt_engines_${trt_dtype}
model_repo=./model_repo_cosyvoice2
use_spk2info_cache=True
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
echo "Cloning CosyVoice"
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
@@ -27,6 +29,8 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Downloading CosyVoice2-0.5B"
huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
# download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings
wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt
fi
@@ -57,10 +61,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
cosyvoice2_dir="cosyvoice2"
cp -r ./model_repo/${cosyvoice2_dir} $model_repo
cp -r ./model_repo/audio_tokenizer $model_repo
cp -r ./model_repo/tensorrt_llm $model_repo
cp -r ./model_repo/token2wav $model_repo
cp -r ./model_repo/speaker_embedding $model_repo
if [ $use_spk2info_cache == "False" ]; then
cp -r ./model_repo/audio_tokenizer $model_repo
cp -r ./model_repo/speaker_embedding $model_repo
fi
ENGINE_PATH=$trt_engines_dir
MAX_QUEUE_DELAY_MICROSECONDS=0
@@ -71,11 +77,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
DECOUPLED_MODE=True # True for streaming, False for offline
python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
if [ $use_spk2info_cache == "False" ]; then
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
@@ -94,7 +101,7 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Running benchmark client grpc"
num_task=1
num_task=4
mode=streaming
BLS_INSTANCE_NUM=4
@@ -104,6 +111,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
--model-name cosyvoice2 \
--num-tasks $num_task \
--mode $mode \
--use-spk2info-cache $use_spk2info_cache \
--huggingface-dataset yuekai/seed_tts_cosy2 \
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache}
fi