mark stateless token2wav

commit 31a0adc73d
parent 482464ea27
Author: root
Date: 2025-09-26 14:51:41 +08:00

5 changed files with 266 additions and 121 deletions


@@ -103,39 +103,91 @@ class TritonPythonModel:
             List of inference responses containing generated waveforms
         """
         responses = []
         # Process each request in batch
         for request in requests:
-            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
-            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device)
-            # shift the speech tokens according to the original vocab size
-            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
-            request_id = request.request_id()
+            # Get inputs
+            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens")
+            target_speech_tokens = torch.utils.dlpack.from_dlpack(target_speech_tokens_tensor.to_dlpack())
+            target_speech_tokens = target_speech_tokens.squeeze().tolist()
             # We set token_offset as an optional input to support streaming/offline TTS; it has to be None for offline TTS.
             finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
-            wav_array = pb_utils.get_input_tensor_by_name(
-                request, "reference_wav").as_numpy()
-            wav_len = pb_utils.get_input_tensor_by_name(
-                request, "reference_wav_len").as_numpy().item()
-            wav_array = torch.from_numpy(wav_array)
-            wav = wav_array[:, :wav_len].squeeze(0)
+            request_id = request.request_id()
+            # Prepare inputs
+            wav_array = pb_utils.get_input_tensor_by_name(request, "reference_wav").as_numpy()
+            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len").as_numpy().item()
+            wav = torch.from_numpy(wav_array)[:, :wav_len].squeeze(0)
             spk_id = get_spk_id_from_prompt_audio(wav)
-            # wav = wav.to(self.device)
-            audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
-            generated_wave = audio_hat.squeeze(0).cpu().numpy()
+            # Handle cache
+            conformer_cnn_cache = pb_utils.get_input_tensor_by_name(request, "conformer_cnn_cache")
+            if conformer_cnn_cache is not None:
+                self.token2wav_model.streaming_flow_cache[request_id]['conformer_cnn_cache'] = torch.utils.dlpack.from_dlpack(conformer_cnn_cache.to_dlpack())
+                conformer_att_cache_np = pb_utils.get_input_tensor_by_name(request, "conformer_att_cache")
+                self.token2wav_model.streaming_flow_cache[request_id]['conformer_att_cache'] = torch.utils.dlpack.from_dlpack(conformer_att_cache_np.to_dlpack()).transpose(0, 1)
+                estimator_cnn_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_cnn_cache")
+                self.token2wav_model.streaming_flow_cache[request_id]['estimator_cnn_cache'] = torch.utils.dlpack.from_dlpack(estimator_cnn_cache_np.to_dlpack()).squeeze(0)
+                estimator_att_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_att_cache")
+                self.token2wav_model.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.utils.dlpack.from_dlpack(estimator_att_cache_np.to_dlpack()).squeeze(0)
+                mel_np = pb_utils.get_input_tensor_by_name(request, "mel")
+                self.token2wav_model.streaming_flow_cache[request_id]['mel'] = torch.utils.dlpack.from_dlpack(mel_np.to_dlpack())
+                source_np = pb_utils.get_input_tensor_by_name(request, "source")
+                self.token2wav_model.hift_cache_dict[request_id]['source'] = torch.utils.dlpack.from_dlpack(source_np.to_dlpack())
+                speech_np = pb_utils.get_input_tensor_by_name(request, "speech")
+                self.token2wav_model.hift_cache_dict[request_id]['speech'] = torch.utils.dlpack.from_dlpack(speech_np.to_dlpack())
+            # Forward pass
+            audio_hat = self.token2wav_model.forward_streaming(
+                target_speech_tokens,
+                finalize,
+                request_id=request_id,
+                speaker_id=f"{spk_id}",
+                prompt_audio=wav,
+                prompt_audio_sample_rate=16000,
+            )
+            # Prepare outputs
+            outputs = []
             wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
-            inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
-            responses.append(inference_response)
+            outputs.append(wav_tensor)
+            if request_id in self.token2wav_model.streaming_flow_cache:
+                cache = self.token2wav_model.streaming_flow_cache[request_id]
+                hifigan_cache = self.token2wav_model.hift_cache_dict[request_id]
+                conformer_cnn_cache = cache['conformer_cnn_cache']
+                conformer_att_cache = cache['conformer_att_cache'].transpose(0, 1)
+                estimator_cnn_cache = cache['estimator_cnn_cache'].unsqueeze(0)
+                estimator_att_cache = cache['estimator_att_cache'].unsqueeze(0)
+                mel = hifigan_cache['mel']
+                source = hifigan_cache['source']
+                speech = hifigan_cache['speech']
+                outputs.extend([
+                    pb_utils.Tensor.from_dlpack("conformer_cnn_cache", to_dlpack(conformer_cnn_cache)),
+                    pb_utils.Tensor.from_dlpack("conformer_att_cache", to_dlpack(conformer_att_cache)),
+                    pb_utils.Tensor.from_dlpack("estimator_cnn_cache", to_dlpack(estimator_cnn_cache)),
+                    pb_utils.Tensor.from_dlpack("estimator_att_cache", to_dlpack(estimator_att_cache)),
+                    pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)),
+                    pb_utils.Tensor.from_dlpack("source", to_dlpack(source)),
+                    pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)),
+                ])
+            else:
+                outputs.extend([
+                    pb_utils.Tensor("conformer_cnn_cache", np.array([], dtype=np.float16)),
+                    pb_utils.Tensor("conformer_att_cache", np.array([], dtype=np.float16)),
+                    pb_utils.Tensor("estimator_cnn_cache", np.array([], dtype=np.float16)),
+                    pb_utils.Tensor("estimator_att_cache", np.array([], dtype=np.float16)),
+                    pb_utils.Tensor("mel", np.array([], dtype=np.float32)),
+                    pb_utils.Tensor("source", np.array([], dtype=np.float32)),
+                    pb_utils.Tensor("speech", np.array([], dtype=np.float32)),
+                ])
+            inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
+            responses.append(inference_response)
         return responses

     def finalize(self):
         self.logger.log_info("Finalizing Token2WavDiT model")
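
What this hunk buys: the Triton model no longer keeps streaming state between calls. The flow-decoder caches (conformer/estimator attention and CNN caches) and the vocoder caches (mel, source, speech) arrive as optional request inputs and are shipped back as response outputs, so any replica can serve any chunk of a stream. Below is a minimal sketch of that round-trip pattern, with plain torch tensors standing in for pb_utils tensors; every name and shape is a placeholder for illustration, not the real model.

import torch

def token2wav_step(tokens, caches, finalize):
    # Stateless decode step: all streaming state comes in through `caches`
    # and leaves with the result; the server itself remembers nothing.
    # (Sketch only -- a real step would run the flow decoder and vocoder.)
    if caches is None:
        # First chunk of a stream: start from empty caches.
        caches = {
            "conformer_att_cache": torch.zeros(2, 8, 0, 64),
            "conformer_cnn_cache": torch.zeros(2, 256, 0),
            "estimator_att_cache": torch.zeros(2, 8, 0, 64),
            "estimator_cnn_cache": torch.zeros(2, 256, 0),
            "mel": torch.zeros(1, 80, 0),
            "source": torch.zeros(1, 1, 0),
            "speech": torch.zeros(1, 0),
        }
    audio = torch.zeros(1, 480 * len(tokens))  # placeholder waveform
    # ...compute audio and the updated caches for this chunk here...
    return audio, None if finalize else caches

# Client side: the caches returned by one call are fed into the next request,
# so consecutive chunks of one stream may hit different server replicas.
caches = None
chunks = [[1, 2, 3], [4, 5, 6], [7, 8]]
for i, chunk in enumerate(chunks):
    audio, caches = token2wav_step(chunk, caches, finalize=(i == len(chunks) - 1))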


@@ -372,7 +372,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
     ):
         if speaker_id not in self.speaker_cache:
-            # if 1:
             assert prompt_audio is not None, "prompt_audio is required for new speaker"
             assert prompt_audio_sample_rate == 16000
@@ -384,20 +383,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow}
-            # if speaker_id not in self.speaker_cache:
-            # if 1:
             cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
             self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
             print(f"speaker_id {speaker_id} added to cache")
-        # get a clone of cache dict ['estimator_att_cache'] and later check if it would be change
-        att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache'].clone()
-        cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache'].clone()
-        conformer_cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache'].clone()
-        conformer_att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache'].clone()
         if request_id not in self.streaming_flow_cache:
             self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
             self.hift_cache_dict[request_id] = dict(
@@ -405,6 +394,12 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
                 source = torch.zeros(1, 1, 0, device='cuda'),
                 speech = torch.zeros(1, 0, device='cuda'),
             )
+        # else:
+        #     for k, v in self.streaming_flow_cache[request_id].items():
+        #         print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}")
+        #     for k, v in self.hift_cache_dict[request_id].items():
+        #         print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}")
+        #     breakpoint()
         current_request_cache = self.streaming_flow_cache[request_id]
@@ -420,33 +415,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             n_timesteps=10,
         )
-        # get the original att_cache
-        original_att_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache']
-        original_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache']
-        original_conformer_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache']
-        original_conformer_att_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache']
-        if not torch.allclose(original_att_cache, att_cache_clone):
-            print("att_cache changed")
-            # print the last 10 elements of original_att_cache and att_cache_clone
-            print(original_att_cache[:, :, :, -10:])
-            print(att_cache_clone[:, :, :, -10:])
-            breakpoint()
-        if not torch.allclose(original_cnn_cache, cnn_cache_clone):
-            print("cnn_cache changed")
-            print(original_cnn_cache[..., -10:])
-            print(cnn_cache_clone[..., -10:])
-            breakpoint()
-        if not torch.allclose(original_conformer_cnn_cache, conformer_cnn_cache_clone):
-            print("conformer_cnn_cache changed")
-            print(original_conformer_cnn_cache[..., -10:])
-            print(conformer_cnn_cache_clone[..., -10:])
-            breakpoint()
-        if not torch.allclose(original_conformer_att_cache, conformer_att_cache_clone):
-            print("conformer_att_cache changed")
-            print(original_conformer_att_cache[..., -10:])
-            print(conformer_att_cache_clone[..., -10:])
-            breakpoint()
         self.streaming_flow_cache[request_id] = new_streaming_flow_cache
@@ -482,7 +450,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         assert request_id in self.streaming_flow_cache
         self.streaming_flow_cache.pop(request_id)
         self.hift_cache_dict.pop(request_id)
-        # breakpoint()
         return speech

 def collate_fn(batch):
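
The blocks deleted in this file were debug scaffolding guarding one invariant: a request must only ever mutate its own clone of the speaker-level prompt caches, never the shared speaker_cache entry. With the clone()/allclose()/breakpoint() checks gone, that contract looks roughly like the sketch below; the dict-of-tensors layout and shapes are assumptions for illustration, not the real cache format.

import torch

speaker_cache: dict[str, dict[str, torch.Tensor]] = {}
streaming_flow_cache: dict[str, dict[str, torch.Tensor]] = {}

def start_or_resume(request_id: str, speaker_id: str) -> dict[str, torch.Tensor]:
    # First chunk of a request: clone the speaker's prompt caches so the
    # per-request cache can be mutated freely while the shared speaker
    # entry stays pristine -- the invariant the deleted checks asserted.
    if request_id not in streaming_flow_cache:
        prompt = speaker_cache[speaker_id]
        streaming_flow_cache[request_id] = {k: v.clone() for k, v in prompt.items()}
    return streaming_flow_cache[request_id]

def finish(request_id: str) -> None:
    # Last chunk: drop per-request state, mirroring the pop() calls above.
    streaming_flow_cache.pop(request_id, None)

# Usage: one cached speaker shared by two concurrent requests.
speaker_cache["spk0"] = {"estimator_att_cache": torch.zeros(1, 8, 16, 64)}
a = start_or_resume("req-a", "spk0")
b = start_or_resume("req-b", "spk0")
a["estimator_att_cache"] += 1.0  # mutate req-a's private copy
assert speaker_cache["spk0"]["estimator_att_cache"].abs().sum() == 0
finish("req-a")
finish("req-b")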