mark stateless token2wav

root
2025-09-26 14:51:41 +08:00
parent 482464ea27
commit 31a0adc73d
5 changed files with 266 additions and 121 deletions

View File

@@ -43,6 +43,7 @@ import torchaudio
 from matcha.utils.audio import mel_spectrogram
+from datetime import datetime

 ORIGINAL_VOCAB_SIZE = 151663
 torch.set_num_threads(1)
@@ -86,6 +87,7 @@ class TritonPythonModel:
         model_params = {k: v["string_value"] for k, v in parameters.items()}
         self.logger.log_info(f"model_params:{model_params}")
         self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
+        # self.dynamic_chunk_strategy = "equal"
         self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")

         # Initialize tokenizer
@@ -105,7 +107,9 @@ class TritonPythonModel:
         if not os.path.exists(spk_info_path):
             raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
         spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
-        # self.default_spk_info = spk_info["001"]
+        self.default_spk_info = spk_info["001"]
+        self.http_client = httpx.AsyncClient()
+        self.runtime_cache = {}

     def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
         """Converts a tensor or list of speech token IDs to a string representation."""
@@ -131,7 +135,6 @@ class TritonPythonModel:
             {"role": "user", "content": full_text},
             {"role": "assistant", "content": prompt_speech_tokens_str}
         ]
-        print(chat)

         payload = {
             "model": "trt_engines_bfloat16",
@@ -148,31 +151,33 @@ class TritonPythonModel:
         api_base = "http://localhost:8000/v1/chat/completions"
         buffer = ""
-        async with httpx.AsyncClient() as client:
-            async with client.stream("POST", api_base, json=payload, timeout=None) as response:
+        async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response:
+            print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}")
+            print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}")
             response.raise_for_status()
             async for line in response.aiter_lines():
                 if line.startswith("data: "):
                     line_data = line[len("data: "):].strip()
                     if line_data == "[DONE]":
                         break
                     try:
                         json_data = json.loads(line_data)
                         content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
                         if content:
                             buffer += content
+                            print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}")
                             while True:
                                 match = re.search(r"<\|s_(\d+)\|>", buffer)
                                 if not match:
                                     break
                                 token_num = int(match.group(1))
                                 final_id = token_num + ORIGINAL_VOCAB_SIZE
                                 yield final_id
                                 buffer = buffer[match.end():]
                     except json.JSONDecodeError:
                         self.logger.log_info(f"Skipping non-JSON line: {line_data}")
                         continue

         # Process any remaining complete tokens in the buffer after the stream ends
         while True:
@@ -236,7 +241,7 @@ class TritonPythonModel:
         return prompt_spk_embedding

-    def forward_token2wav(
+    async def forward_token2wav(
         self,
         index: int,
         target_speech_tokens: torch.Tensor,
@@ -259,19 +264,56 @@ class TritonPythonModel:
         finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))

         inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
+        # optional cache inputs
+        if self.runtime_cache[request_id]["conformer_cnn_cache"] is not None:
+            # inputs_tensor.extend([
+            #     pb_utils.Tensor("conformer_cnn_cache", self.runtime_cache[request_id]["conformer_cnn_cache"].as_numpy()),
+            #     pb_utils.Tensor("conformer_att_cache", self.runtime_cache[request_id]["conformer_att_cache"].as_numpy()),
+            #     pb_utils.Tensor("estimator_cnn_cache", self.runtime_cache[request_id]["estimator_cnn_cache"].as_numpy()),
+            #     pb_utils.Tensor("estimator_att_cache", self.runtime_cache[request_id]["estimator_att_cache"].as_numpy()),
+            #     pb_utils.Tensor("mel", self.runtime_cache[request_id]["mel"].as_numpy()),
+            #     pb_utils.Tensor("source", self.runtime_cache[request_id]["source"].as_numpy()),
+            #     pb_utils.Tensor("speech", self.runtime_cache[request_id]["speech"].as_numpy()),
+            # ])
+            inputs_tensor.extend([
+                self.runtime_cache[request_id]["conformer_cnn_cache"],
+                self.runtime_cache[request_id]["conformer_att_cache"],
+                self.runtime_cache[request_id]["estimator_cnn_cache"],
+                self.runtime_cache[request_id]["estimator_att_cache"],
+                self.runtime_cache[request_id]["mel"],
+                self.runtime_cache[request_id]["source"],
+                self.runtime_cache[request_id]["speech"],
+            ])

         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav_dit',
-            requested_output_names=['waveform'],
+            requested_output_names=[
+                "waveform",
+                "conformer_cnn_cache",
+                "conformer_att_cache",
+                "estimator_cnn_cache",
+                "estimator_att_cache",
+                "mel",
+                "source",
+                "speech",
+            ],
             inputs=inputs_tensor,
             request_id=request_id,
             parameters={"priority": index+1},
         )
-        inference_response = inference_request.exec()
+        inference_response = await inference_request.async_exec()

         if inference_response.has_error():
             raise pb_utils.TritonModelException(inference_response.error().message())
+        self.runtime_cache[request_id]["conformer_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_cnn_cache")
+        self.runtime_cache[request_id]["conformer_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_att_cache")
+        self.runtime_cache[request_id]["estimator_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_cnn_cache")
+        self.runtime_cache[request_id]["estimator_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_att_cache")
+        self.runtime_cache[request_id]["mel"] = pb_utils.get_output_tensor_by_name(inference_response, "mel")
+        self.runtime_cache[request_id]["source"] = pb_utils.get_output_tensor_by_name(inference_response, "source")
+        self.runtime_cache[request_id]["speech"] = pb_utils.get_output_tensor_by_name(inference_response, "speech")

         # Extract and convert output waveform
         waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
         waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
@@ -297,6 +339,16 @@ class TritonPythonModel:
     async def _process_request(self, request):
         request_id = request.request_id()
+        if request_id not in self.runtime_cache:
+            self.runtime_cache[request_id] = {
+                "conformer_cnn_cache": None,
+                "conformer_att_cache": None,
+                "estimator_cnn_cache": None,
+                "estimator_att_cache": None,
+                "mel": None,
+                "source": None,
+                "speech": None,
+            }

         # Extract input tensors
         wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
@@ -308,6 +360,7 @@ class TritonPythonModel:
         wav_tensor = wav.as_numpy()
         wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
+        print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}")
         prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
         speech_feat = self._extract_speech_feat(prompt_speech_resample)
         token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
@@ -316,7 +369,7 @@ class TritonPythonModel:
         reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
         reference_text = reference_text[0][0].decode('utf-8')

-        # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+        prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
         # reference_text = self.default_spk_info["prompt_text"]
         # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
@@ -333,6 +386,7 @@ class TritonPythonModel:
         target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
         target_text = target_text[0][0].decode('utf-8')
+        print(f"target_text: {target_text}, time: {datetime.now()}")

         if self.decoupled:
             response_sender = request.get_response_sender()
@@ -341,7 +395,7 @@ class TritonPythonModel:
             token_offset, chunk_index = 0, 0
             start_time = time.time()
             this_token_hop_len = self.token_hop_len
+            print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}")
             async for generated_ids in self.forward_llm_async(
                 target_text=target_text,
                 reference_text=reference_text,
@@ -350,18 +404,18 @@ class TritonPythonModel:
                 if not generated_ids:
                     break
                 semantic_token_ids_arr.append(generated_ids)
+                print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}")
                 while True:
                     pending_num = len(semantic_token_ids_arr) - token_offset
                     if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
                         this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                         this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
-                        sub_tts_speech = self.forward_token2wav(
+                        print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}")
+                        sub_tts_speech = await self.forward_token2wav(
                             chunk_index,
                             this_tts_speech_token, request_id, wav, wav_len, False
                         )
+                        print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}")
                         audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                         inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                         response_sender.send(inference_response)
@@ -371,6 +425,8 @@ class TritonPythonModel:
                         if self.dynamic_chunk_strategy == "exponential":
                             this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
+                        elif self.dynamic_chunk_strategy == "equal":
+                            this_token_hop_len = self.token_hop_len
                         elif self.dynamic_chunk_strategy == "time_based":
                             # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
                             cost_time = time.time() - start_time
@@ -393,29 +449,13 @@ class TritonPythonModel:
                         break

             this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
-            sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
+            sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
             audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
             inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
             response_sender.send(inference_response)
-            ## debug
-            ## save semantic_token_ids_arr and reference_text, target_text to a single json file
-            # save into a torch .pt
-            # for i, item in enumerate(semantic_token_ids_arr):
-            #     semantic_token_ids_arr[i] = item - ORIGINAL_VOCAB_SIZE
-            # import json
-            # data = {
-            #     "semantic_token_ids_arr": semantic_token_ids_arr,
-            #     "reference_text": reference_text,
-            #     "target_text": target_text
-            # }
-            # with open(f"semantic_token_ids_arr_debug_{request_id}.pt", "wb") as f:
-            #     torch.save(data, f)
-            # with open(f"semantic_token_ids_arr_debug_{request_id}.json", "w") as f:
-            #     json.dump(data, f)
-            # ##
+            if request_id in self.runtime_cache:
+                del self.runtime_cache[request_id]
+                self.logger.log_info(f"Deleted cache for request_id: {request_id}")

             response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
             self.logger.log_info("send tritonserver_response_complete_final to end")
         else:
@@ -436,3 +476,8 @@ class TritonPythonModel:
             ]
             await asyncio.gather(*tasks)
         return None
+
+    def finalize(self):
+        self.logger.log_info("Finalizing CosyVoice DIT model")
+        if hasattr(self, "http_client"):
+            asyncio.run(self.http_client.aclose())
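
Taken together, the changes in this file make the token2wav_dit call stateless from the worker's point of view: the orchestrator keeps one cache dict per request, passes the cached tensors back in as optional inputs on every chunk, and overwrites them with whatever the worker returns. A minimal sketch of that per-chunk round-trip, using the tensor names from this diff; the CACHE_NAMES constant and infer_chunk helper are illustrative, and the reference-audio inputs are omitted for brevity:

import triton_python_backend_utils as pb_utils

# Cache tensors threaded through every token2wav_dit call (names taken from this diff).
CACHE_NAMES = [
    "conformer_cnn_cache", "conformer_att_cache",
    "estimator_cnn_cache", "estimator_att_cache",
    "mel", "source", "speech",
]

async def infer_chunk(cache, speech_tokens_tensor, finalize_tensor, request_id, index):
    """One stateless token2wav_dit call: send cached state in, read updated state back."""
    inputs = [speech_tokens_tensor, finalize_tensor]
    if cache["conformer_cnn_cache"] is not None:  # every chunk after the first
        inputs.extend(cache[name] for name in CACHE_NAMES)
    request = pb_utils.InferenceRequest(
        model_name="token2wav_dit",
        requested_output_names=["waveform"] + CACHE_NAMES,
        inputs=inputs,
        request_id=request_id,
        parameters={"priority": index + 1},
    )
    response = await request.async_exec()
    if response.has_error():
        raise pb_utils.TritonModelException(response.error().message())
    # Overwrite the per-request cache with the state returned for this chunk.
    for name in CACHE_NAMES:
        cache[name] = pb_utils.get_output_tensor_by_name(response, name)
    return pb_utils.get_output_tensor_by_name(response, "waveform")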

View File

@@ -31,7 +31,7 @@ parameters [
     value: {string_value:"${model_dir}"}
   }
 ]
+parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}}

 input [
   {
     name: "reference_wav"

View File

@@ -103,39 +103,91 @@ class TritonPythonModel:
             List of inference responses containing generated waveforms
         """
         responses = []
-        # Process each request in batch
         for request in requests:
-            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
-            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device)
-            # shift the speech tokens according to the original vocab size
-            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
-            target_speech_tokens = target_speech_tokens.squeeze().tolist()
-
-            # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
-            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
             request_id = request.request_id()
-            wav_array = pb_utils.get_input_tensor_by_name(
-                request, "reference_wav").as_numpy()
-            wav_len = pb_utils.get_input_tensor_by_name(
-                request, "reference_wav_len").as_numpy().item()
-            wav_array = torch.from_numpy(wav_array)
-            # Prepare inputs
-            wav = wav_array[:, :wav_len].squeeze(0)
+
+            # Get inputs
+            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens")
+            target_speech_tokens = torch.utils.dlpack.from_dlpack(target_speech_tokens_tensor.to_dlpack())
+            target_speech_tokens = target_speech_tokens.squeeze().tolist()
+            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
+            wav_array = pb_utils.get_input_tensor_by_name(request, "reference_wav").as_numpy()
+            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len").as_numpy().item()
+            wav = torch.from_numpy(wav_array)[:, :wav_len].squeeze(0)
             spk_id = get_spk_id_from_prompt_audio(wav)
-            # wav = wav.to(self.device)
-
-            audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
-
-            generated_wave = audio_hat.squeeze(0).cpu().numpy()
+
+            # Handle cache
+            conformer_cnn_cache = pb_utils.get_input_tensor_by_name(request, "conformer_cnn_cache")
+            if conformer_cnn_cache is not None:
+                self.token2wav_model.streaming_flow_cache[request_id]['conformer_cnn_cache'] = torch.utils.dlpack.from_dlpack(conformer_cnn_cache.to_dlpack())
+                conformer_att_cache_np = pb_utils.get_input_tensor_by_name(request, "conformer_att_cache")
+                self.token2wav_model.streaming_flow_cache[request_id]['conformer_att_cache'] = torch.utils.dlpack.from_dlpack(conformer_att_cache_np.to_dlpack()).transpose(0, 1)
+                estimator_cnn_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_cnn_cache")
+                self.token2wav_model.streaming_flow_cache[request_id]['estimator_cnn_cache'] = torch.utils.dlpack.from_dlpack(estimator_cnn_cache_np.to_dlpack()).squeeze(0)
+                estimator_att_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_att_cache")
+                self.token2wav_model.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.utils.dlpack.from_dlpack(estimator_att_cache_np.to_dlpack()).squeeze(0)
+                mel_np = pb_utils.get_input_tensor_by_name(request, "mel")
+                self.token2wav_model.streaming_flow_cache[request_id]['mel'] = torch.utils.dlpack.from_dlpack(mel_np.to_dlpack())
+                source_np = pb_utils.get_input_tensor_by_name(request, "source")
+                self.token2wav_model.hift_cache_dict[request_id]['source'] = torch.utils.dlpack.from_dlpack(source_np.to_dlpack())
+                speech_np = pb_utils.get_input_tensor_by_name(request, "speech")
+                self.token2wav_model.hift_cache_dict[request_id]['speech'] = torch.utils.dlpack.from_dlpack(speech_np.to_dlpack())
+
+            # Forward pass
+            audio_hat = self.token2wav_model.forward_streaming(
+                target_speech_tokens,
+                finalize,
+                request_id=request_id,
+                speaker_id=f"{spk_id}",
+                prompt_audio=wav,
+                prompt_audio_sample_rate=16000
+            )
+
+            # Prepare outputs
+            outputs = []
             wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
-            inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
-            responses.append(inference_response)
+            outputs.append(wav_tensor)
+
+            if request_id in self.token2wav_model.streaming_flow_cache:
+                cache = self.token2wav_model.streaming_flow_cache[request_id]
+                hifigan_cache = self.token2wav_model.hift_cache_dict[request_id]
+                conformer_cnn_cache = cache['conformer_cnn_cache']
+                conformer_att_cache = cache['conformer_att_cache'].transpose(0, 1)
+                estimator_cnn_cache = cache['estimator_cnn_cache'].unsqueeze(0)
+                estimator_att_cache = cache['estimator_att_cache'].unsqueeze(0)
+                mel = hifigan_cache['mel']
+                source = hifigan_cache['source']
+                speech = hifigan_cache['speech']
+                outputs.extend([
+                    pb_utils.Tensor.from_dlpack("conformer_cnn_cache", to_dlpack(conformer_cnn_cache)),
+                    pb_utils.Tensor.from_dlpack("conformer_att_cache", to_dlpack(conformer_att_cache)),
+                    pb_utils.Tensor.from_dlpack("estimator_cnn_cache", to_dlpack(estimator_cnn_cache)),
+                    pb_utils.Tensor.from_dlpack("estimator_att_cache", to_dlpack(estimator_att_cache)),
+                    pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)),
+                    pb_utils.Tensor.from_dlpack("source", to_dlpack(source)),
+                    pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)),
+                ])
+            else:
+                outputs.extend([
+                    pb_utils.Tensor("conformer_cnn_cache", np.array([], dtype=np.float16)),
+                    pb_utils.Tensor("conformer_att_cache", np.array([], dtype=np.float16)),
+                    pb_utils.Tensor("estimator_cnn_cache", np.array([], dtype=np.float16)),
+                    pb_utils.Tensor("estimator_att_cache", np.array([], dtype=np.float16)),
+                    pb_utils.Tensor("mel", np.array([], dtype=np.float32)),
+                    pb_utils.Tensor("source", np.array([], dtype=np.float32)),
+                    pb_utils.Tensor("speech", np.array([], dtype=np.float32)),
+                ])
+
+            inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
+            responses.append(inference_response)

         return responses
+
+    def finalize(self):
+        self.logger.log_info("Finalizing Token2WavDiT model")

View File

@@ -372,7 +372,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
     ):
         if speaker_id not in self.speaker_cache:
-        # if 1:
             assert prompt_audio is not None, "prompt_audio is required for new speaker"
             assert prompt_audio_sample_rate == 16000
@@ -384,20 +383,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow}

-            # if speaker_id not in self.speaker_cache:
-            # if 1:
             cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
             self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
             print(f"speaker_id {speaker_id} added to cache")

-        # get a clone of cache dict ['estimator_att_cache'] and later check if it would be change
-        att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache'].clone()
-        cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache'].clone()
-        conformer_cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache'].clone()
-        conformer_att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache'].clone()
         if request_id not in self.streaming_flow_cache:
             self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
             self.hift_cache_dict[request_id] = dict(
@@ -405,6 +394,12 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
                 source = torch.zeros(1, 1, 0, device='cuda'),
                 speech = torch.zeros(1, 0, device='cuda'),
             )
+        # else:
+        #     for k, v in self.streaming_flow_cache[request_id].items():
+        #         print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}")
+        #     for k, v in self.hift_cache_dict[request_id].items():
+        #         print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}")
+        #     breakpoint()

         current_request_cache = self.streaming_flow_cache[request_id]
@@ -420,33 +415,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             n_timesteps=10,
         )

-        # get the original att_cache
-        original_att_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache']
-        original_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache']
-        original_conformer_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache']
-        original_conformer_att_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache']
-        if not torch.allclose(original_att_cache, att_cache_clone):
-            print("att_cache changed")
-            # print the last 10 elements of original_att_cache and att_cache_clone
-            print(original_att_cache[:, :, :, -10:])
-            print(att_cache_clone[:, :, :, -10:])
-            breakpoint()
-        if not torch.allclose(original_cnn_cache, cnn_cache_clone):
-            print("cnn_cache changed")
-            print(original_cnn_cache[..., -10:])
-            print(cnn_cache_clone[..., -10:])
-            breakpoint()
-        if not torch.allclose(original_conformer_cnn_cache, conformer_cnn_cache_clone):
-            print("conformer_cnn_cache changed")
-            print(original_conformer_cnn_cache[..., -10:])
-            print(conformer_cnn_cache_clone[..., -10:])
-            breakpoint()
-        if not torch.allclose(original_conformer_att_cache, conformer_att_cache_clone):
-            print("conformer_att_cache changed")
-            print(original_conformer_att_cache[..., -10:])
-            print(conformer_att_cache_clone[..., -10:])
-            breakpoint()
         self.streaming_flow_cache[request_id] = new_streaming_flow_cache
@@ -482,7 +450,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             assert request_id in self.streaming_flow_cache
             self.streaming_flow_cache.pop(request_id)
             self.hift_cache_dict.pop(request_id)
-            # breakpoint()
         return speech

 def collate_fn(batch):
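
With the debug scaffolding removed, the cache lifecycle left in forward_streaming is: clone the per-speaker cache into a per-request flow cache on the first chunk, update it on every chunk, and pop it together with the HiFi-GAN cache on the last chunk. A condensed sketch of that lifecycle, assuming the attribute names above; the two helper methods are illustrative, and the 'mel' initialization is an assumption since it sits just above the hunk shown:

import torch

class _CacheLifecycleSketch:
    def _ensure_request_cache(self, request_id: str, speaker_id: str):
        if request_id not in self.streaming_flow_cache:
            # Seed the per-request flow cache from the per-speaker cache (clones, so the
            # speaker cache itself is never mutated by a running request).
            self.streaming_flow_cache[request_id] = {
                k: v.clone() for k, v in self.speaker_cache[speaker_id]["cache_dict"].items()
            }
            self.hift_cache_dict[request_id] = dict(
                mel=torch.zeros(1, 80, 0, device="cuda"),  # assumed shape; initialized above the hunk shown
                source=torch.zeros(1, 1, 0, device="cuda"),
                speech=torch.zeros(1, 0, device="cuda"),
            )

    def _drop_request_cache(self, request_id: str):
        # Called on the last chunk so no per-request state survives the response.
        assert request_id in self.streaming_flow_cache
        self.streaming_flow_cache.pop(request_id)
        self.hift_cache_dict.pop(request_id)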

View File

@@ -15,11 +15,14 @@
 name: "token2wav_dit"
 backend: "python"
 max_batch_size: ${triton_max_batch_size}

 dynamic_batching {
     max_queue_delay_microseconds: ${max_queue_delay_microseconds}
     priority_levels: 10
     default_priority_level: 10
 }
+parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}}
 parameters [
   {
     key: "model_dir",
@@ -49,6 +52,48 @@ input [
     dims: [ 1 ]
     reshape: { shape: [ ] }
     optional: true
+  },
+  {
+    name: "conformer_cnn_cache"
+    data_type: TYPE_FP16
+    dims: [ 512, -1 ]
+    optional: true
+  },
+  {
+    name: "conformer_att_cache"
+    data_type: TYPE_FP16
+    dims: [ 10, 8, -1, 128 ]
+    optional: true
+  },
+  {
+    name: "estimator_cnn_cache"
+    data_type: TYPE_FP16
+    dims: [ 10, 16, -1, 1024, 2 ]
+    optional: true
+  },
+  {
+    name: "estimator_att_cache"
+    data_type: TYPE_FP16
+    dims: [ 10, 16, -1, 8, -1, 128 ]
+    optional: true
+  },
+  {
+    name: "mel"
+    data_type: TYPE_FP32
+    dims: [ 80, -1 ]
+    optional: true
+  },
+  {
+    name: "source"
+    data_type: TYPE_FP32
+    dims: [ 1, -1 ]
+    optional: true
+  },
+  {
+    name: "speech"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+    optional: true
   }
 ]

 output [
@@ -56,6 +101,41 @@ output [
     name: "waveform"
     data_type: TYPE_FP32
     dims: [ -1 ]
+  },
+  {
+    name: "conformer_cnn_cache"
+    data_type: TYPE_FP16
+    dims: [ 512, -1 ]
+  },
+  {
+    name: "conformer_att_cache"
+    data_type: TYPE_FP16
+    dims: [ 10, 8, -1, 128 ]
+  },
+  {
+    name: "estimator_cnn_cache"
+    data_type: TYPE_FP16
+    dims: [ 10, 16, -1, 1024, 2 ]
+  },
+  {
+    name: "estimator_att_cache"
+    data_type: TYPE_FP16
+    dims: [ 10, 16, -1, 8, -1, 128 ]
+  },
+  {
+    name: "mel"
+    data_type: TYPE_FP32
+    dims: [ 80, -1 ]
+  },
+  {
+    name: "source"
+    data_type: TYPE_FP32
+    dims: [ 1, -1 ]
+  },
+  {
+    name: "speech"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
   }
 ]