This commit is contained in:
root
2025-07-29 08:40:51 +00:00
parent d1c354eac7
commit 62d082634e
7 changed files with 71 additions and 68 deletions

View File

@@ -14,6 +14,13 @@
# limitations under the License.
"""Pytriton server for token2wav conversion and ASR"""
from datasets import load_dataset
from cosyvoice.cli.cosyvoice import CosyVoice2
from omnisense.models import OmniSenseVoiceSmall
from pytriton.proxy.types import Request
from pytriton.triton import Triton, TritonConfig
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
from pytriton.decorators import batch
import argparse
import io
import logging
@@ -37,15 +44,6 @@ zh_tn_model = ZhNormalizer(
overwrite_cache=True,
)
from pytriton.decorators import batch
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
from pytriton.triton import Triton, TritonConfig
from pytriton.proxy.types import Request
from omnisense.models import OmniSenseVoiceSmall
from cosyvoice.cli.cosyvoice import CosyVoice2
from datasets import load_dataset
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
@@ -78,7 +76,6 @@ class _ASR_Server:
return {"TRANSCRIPTS": transcripts}
def audio_decode_cosyvoice2(
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
):
@@ -123,12 +120,12 @@ def get_random_prompt_from_dataset(dataset):
"""
random_idx = random.randint(0, len(dataset) - 1)
sample = dataset[random_idx]
# Extract audio data
audio_data = sample["audio"]
audio_array = audio_data["array"]
sample_rate = audio_data["sampling_rate"]
# Convert audio to 16kHz if needed
if sample_rate != 16000:
num_samples = int(len(audio_array) * (16000 / sample_rate))
@@ -141,6 +138,7 @@ def get_random_prompt_from_dataset(dataset):
prompt_text = prompt_text.replace(" ", "")
return prompt_text, prompt_speech_16k
class _Token2Wav_ASR:
"""Wraps a single OmniSenseVoiceSmall model instance together with a CosyVoice2 codec decoder for Triton."""
@@ -163,6 +161,7 @@ class _Token2Wav_ASR:
self.codec_decoder = CosyVoice2(
"/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
)
@batch
def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
"""
@@ -236,7 +235,6 @@ class _Token2Wav_ASR:
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}