Clean up code: remove debug prints and replace the hard-coded 151663 offset with ORIGINAL_VOCAB_SIZE

This commit is contained in:
Yuekai Zhang
2025-07-27 23:33:10 -07:00
parent 5427c274e3
commit 178da09993
6 changed files with 23 additions and 29 deletions

View File

@@ -33,6 +33,7 @@ import os
import numpy as np
import s3tokenizer
ORIGINAL_VOCAB_SIZE = 151663
class TritonPythonModel:
"""Triton Python model for audio tokenization.
@@ -81,7 +82,7 @@ class TritonPythonModel:
mels, mels_lens = s3tokenizer.padding(mels)
codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
codes = codes.clone() + 151663
codes = codes.clone() + ORIGINAL_VOCAB_SIZE
responses = []
for i in range(len(requests)):

View File

@@ -199,8 +199,6 @@ class TritonPythonModel:
Returns:
Generated waveform tensor
"""
print(prompt_speech_tokens.shape, prompt_speech_feat.shape, prompt_spk_embedding.shape, target_speech_tokens.shape)
# Convert tensors to Triton format
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
@@ -228,9 +226,7 @@ class TritonPythonModel:
prompt = self.prompt_template.format(input_text=total_text)
input_ids = self.tokenizer.encode(prompt)
input_ids = torch.tensor([input_ids], dtype=torch.int32)
print(input_ids.shape, "before cat")
input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
print(input_ids.shape, "after cat", prompt_speech_tokens.shape)
return input_ids
def _extract_spk_embedding(self, speech):
@@ -271,23 +267,15 @@ class TritonPythonModel:
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
# TODO: FIX ME
wav_tensor = wav.as_numpy()
print(wav_tensor.shape, "wav_tensor")
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
print(wav_tensor.shape, "wav_tensor after")
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
speech_feat = self._extract_speech_feat(prompt_speech_resample)
print(speech_feat.shape, "speech_feat")
print(prompt_speech_tokens.shape, "prompt_speech_tokens here")
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
print(prompt_speech_tokens.shape, "prompt_speech_tokens after")
print(speech_feat.shape, "speech_feat after")
print(token_len, "token_len")
# Extract text inputs
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')

View File

@@ -38,13 +38,11 @@ import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
#import sys
#sys.path.append("/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice/third_party/Matcha-TTS")
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
ORIGINAL_VOCAB_SIZE = 151663
class CosyVoice2:
@@ -162,8 +160,9 @@ class TritonPythonModel:
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
prompt_speech_tokens = prompt_speech_tokens - 151663
target_speech_tokens = target_speech_tokens - 151663
# shift the speech tokens back down by subtracting the original vocab size offset
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
tts_mel, _ = self.token2wav_model.model.flow.inference(
token=target_speech_tokens,