# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Pytriton server for token2wav conversion and ASR""" from datasets import load_dataset from cosyvoice.cli.cosyvoice import CosyVoice2 from omnisense.models import OmniSenseVoiceSmall from pytriton.proxy.types import Request from pytriton.triton import Triton, TritonConfig from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor from pytriton.decorators import batch import argparse import io import logging from typing import Any, List import numpy as np import torch from scipy.signal import resample import sys import random import re from jiwer import wer from pypinyin import lazy_pinyin, Style from tn.chinese.normalizer import Normalizer as ZhNormalizer # Chinese text normalizer (cached globally) zh_tn_model = ZhNormalizer( cache_dir="./cache", remove_erhua=False, remove_interjections=False, remove_puncts=True, overwrite_cache=True, ) sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") logger = logging.getLogger("token2wav_asr_server") class _ASR_Server: """Wraps a single OmniSenseVoiceSmall model instance for Triton.""" def __init__(self, device_id: int): self._model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id) @batch def __call__(self, WAV: np.ndarray, WAV_LENS: np.ndarray, LANGUAGE: np.ndarray, TEXT_NORM: np.ndarray): """ WAV: np.ndarray, WAV_LENS: np.ndarray LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu """ logger.debug("WAV: %s, WAV_LENS: %s, shapes: %s %s", type(WAV), type(WAV_LENS), WAV.shape, WAV_LENS.shape) wavs = [WAV[i, :WAV_LENS[i, 0]] for i in range(len(WAV))] results = self._model.transcribe_single_batch( wavs, language="zh", textnorm="woitn", ) texts = [result.text for result in results] transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8") return {"TRANSCRIPTS": transcripts} def audio_decode_cosyvoice2( audio_tokens, prompt_text, prompt_speech_16k, codec_decoder ): """ Generate audio from tokens with optional tone and prompt embedding. """ model_inputs_dict = codec_decoder.frontend.frontend_zero_shot( "empty", prompt_text, prompt_speech_16k, 24000 ) tts_mel, _ = codec_decoder.model.flow.inference( token=audio_tokens.to(codec_decoder.model.device), token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to( codec_decoder.model.device ), prompt_token=model_inputs_dict["flow_prompt_speech_token"].to( codec_decoder.model.device ), prompt_token_len=torch.tensor( [model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32 ).to(codec_decoder.model.device), prompt_feat=model_inputs_dict["prompt_speech_feat"].to( codec_decoder.model.device ), prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to( codec_decoder.model.device ), embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device), finalize=True, ) audio_hat, _ = codec_decoder.model.hift.inference( speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) ) return audio_hat def get_random_prompt_from_dataset(dataset): """ Get random prompt text and speech from the pre-loaded dataset. Returns (prompt_text, prompt_speech_16k) """ random_idx = random.randint(0, len(dataset) - 1) sample = dataset[random_idx] # Extract audio data audio_data = sample["audio"] audio_array = audio_data["array"] sample_rate = audio_data["sampling_rate"] # Convert audio to 16kHz if needed if sample_rate != 16000: num_samples = int(len(audio_array) * (16000 / sample_rate)) audio_array = resample(audio_array, num_samples) # Convert to torch tensor prompt_speech_16k = torch.from_numpy(audio_array).float().unsqueeze(0) prompt_text = sample["text"] # remove space in prompt_text prompt_text = prompt_text.replace(" ", "") return prompt_text, prompt_speech_16k class _Token2Wav_ASR: """Wraps a single OmniSenseVoiceSmall model instance for Triton.""" def __init__(self, device_id: int): self.asr_model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id) self.dataset = load_dataset("yuekai/aishell", "test", trust_remote_code=True)["test"] # Make sure the CosyVoice2 decoder lives on the same GPU as the ASR model # CosyVoice2 internally uses generic "cuda" device, so we first switch the # current CUDA context to the desired card before the object is created. # Afterwards, all parameters loaded with the generic "cuda" device will # reside on this GPU. We keep the selected id in `self.device_id` and # will set the context again for every forward call to avoid race # conditions when several instances are used in the same process. self.device_id = device_id # Construct the TTS codec decoder under the correct CUDA device context with torch.cuda.device(self.device_id): self.codec_decoder = CosyVoice2( "/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True ) @batch def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray): """ WAV: np.ndarray, WAV_LENS: np.ndarray LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu """ # Ensure the default CUDA device is set correctly for this invocation torch.cuda.set_device(self.device_id) if self.device_id == 0: print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}") tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))] # Decode ground-truth text strings (BYTES → str) if GT_TEXT.ndim == 2: gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))] else: gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))] wavs = [] for tokens in tokens_list: prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset) audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0) audio_hat = audio_decode_cosyvoice2( audio_tokens, prompt_text, prompt_speech_16k, self.codec_decoder, ) # resample to 16000 using soundfile audio_hat = audio_hat.squeeze(0).float().cpu() audio_hat = audio_hat.numpy() num_samples = int(len(audio_hat) * (16000 / 24000)) audio_hat = resample(audio_hat, num_samples) wavs.append(audio_hat) results = self.asr_model.transcribe_single_batch( wavs, language="zh", textnorm="woitn", ) texts = [result.text for result in results] # ---------------- Reward computation ---------------- rewards = [] for gt_text, hyp_text in zip(gt_texts, texts): gt_norm = zh_tn_model.normalize(gt_text).lower() hyp_norm = zh_tn_model.normalize(hyp_text).lower() gt_pinyin = lazy_pinyin( gt_norm, style=Style.TONE3, tone_sandhi=True, neutral_tone_with_five=True, ) hyp_pinyin = lazy_pinyin( hyp_norm, style=Style.TONE3, tone_sandhi=True, neutral_tone_with_five=True, ) c = float(wer(" ".join(gt_pinyin), " ".join(hyp_pinyin))) reward_val = 1.0 - np.tanh(3.0 * c) reward_val = max(0.0, min(1.0, reward_val)) rewards.append(reward_val) print(f"gt_text: {gt_text}, hyp_text: {hyp_text}, reward_val: {reward_val}") transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8") rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1) return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts} def _infer_function_factory(device_ids: List[int], model_name: str): """Creates a list of inference functions, one for each requested device ID.""" infer_funcs = [] for device_id in device_ids: if model_name == "sensevoice": infer_funcs.append(_ASR_Server(device_id=device_id)) else: infer_funcs.append(_Token2Wav_ASR(device_id=device_id)) return infer_funcs def main(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--max-batch-size", type=int, default=32, help="Batch size of request.", required=False, ) parser.add_argument( "--verbose", action="store_true", default=False, ) parser.add_argument( "--number-of-instances-per-device", type=int, default=1, help="Number of model instances to load.", required=False, ) parser.add_argument( "--number-of-devices", type=int, default=8, help="Number of devices to use.", ) parser.add_argument( "--model-name", type=str, default="token2wav_asr", choices=["token2wav_asr", "sensevoice"], help="Model name.", ) args = parser.parse_args() log_level = logging.DEBUG if args.verbose else logging.INFO logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s") triton_config = TritonConfig( http_port=8000, grpc_port=8001, metrics_port=8002, ) device_ids = list(range(args.number_of_devices)) device_ids = device_ids * args.number_of_instances_per_device with Triton(config=triton_config) as triton: logger.info("Loading SenseVoice model on device ids: %s", device_ids) if args.model_name == "sensevoice": triton.bind( model_name="sensevoice", infer_func=_infer_function_factory(device_ids, args.model_name), inputs=[ Tensor(name="WAV", dtype=np.float32, shape=(-1,)), Tensor(name="WAV_LENS", dtype=np.int32, shape=(-1,)), Tensor(name="LANGUAGE", dtype=np.int32, shape=(-1,)), Tensor(name="TEXT_NORM", dtype=np.int32, shape=(-1,)), ], outputs=[ Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)), ], config=ModelConfig( max_batch_size=args.max_batch_size, batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms ), strict=True, ) else: triton.bind( model_name="token2wav_asr", infer_func=_infer_function_factory(device_ids, args.model_name), inputs=[ Tensor(name="TOKENS", dtype=np.int32, shape=(-1,)), Tensor(name="TOKEN_LENS", dtype=np.int32, shape=(-1,)), Tensor(name="GT_TEXT", dtype=bytes, shape=(-1,)), ], outputs=[ Tensor(name="REWARDS", dtype=np.float32, shape=(-1,)), Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)), ], config=ModelConfig( max_batch_size=args.max_batch_size, batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms ), strict=True, ) logger.info("Serving inference") triton.serve() if __name__ == "__main__": main()