This commit is contained in:
root
2025-07-29 08:39:41 +00:00
parent 1b8d194b67
commit 07cbc51cd1
8 changed files with 165 additions and 157 deletions

View File

@@ -44,6 +44,7 @@ logger = logging.getLogger(__name__)
ORIGINAL_VOCAB_SIZE = 151663
class CosyVoice2:
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
@@ -66,6 +67,7 @@ class CosyVoice2:
trt_concurrent,
self.fp16)
class CosyVoice2Model:
def __init__(self,
@@ -109,16 +111,17 @@ class CosyVoice2Model:
input_names = ["x", "mask", "mu", "cond"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
class TritonPythonModel:
"""Triton Python model for vocoder.
This model takes global and semantic tokens as input and generates audio waveforms
using the BiCodec vocoder.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
@@ -126,24 +129,23 @@ class TritonPythonModel:
parameters = json.loads(args['model_config'])['parameters']
model_params = {key: value["string_value"] for key, value in parameters.items()}
model_dir = model_params["model_dir"]
# Initialize device and vocoder
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
self.token2wav_model = CosyVoice2(
model_dir, load_jit=True, load_trt=True, fp16=True
)
logger.info("Token2Wav initialized successfully")
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated waveforms
"""
@@ -163,7 +165,7 @@ class TritonPythonModel:
# shift the speech tokens according to the original vocab size
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
tts_mel, _ = self.token2wav_model.model.flow.inference(
token=target_speech_tokens,
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
@@ -189,9 +191,5 @@ class TritonPythonModel:
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
responses.append(inference_response)
return responses