mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
fix lint
This commit is contained in:
@@ -14,6 +14,13 @@
|
||||
# limitations under the License.
|
||||
"""Pytriton server for token2wav conversion and ASR"""
|
||||
|
||||
from datasets import load_dataset
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||
from omnisense.models import OmniSenseVoiceSmall
|
||||
from pytriton.proxy.types import Request
|
||||
from pytriton.triton import Triton, TritonConfig
|
||||
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
|
||||
from pytriton.decorators import batch
|
||||
import argparse
|
||||
import io
|
||||
import logging
|
||||
@@ -37,15 +44,6 @@ zh_tn_model = ZhNormalizer(
|
||||
overwrite_cache=True,
|
||||
)
|
||||
|
||||
from pytriton.decorators import batch
|
||||
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
|
||||
from pytriton.triton import Triton, TritonConfig
|
||||
from pytriton.proxy.types import Request
|
||||
|
||||
from omnisense.models import OmniSenseVoiceSmall
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
|
||||
@@ -78,7 +76,6 @@ class _ASR_Server:
|
||||
return {"TRANSCRIPTS": transcripts}
|
||||
|
||||
|
||||
|
||||
def audio_decode_cosyvoice2(
|
||||
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
|
||||
):
|
||||
@@ -123,12 +120,12 @@ def get_random_prompt_from_dataset(dataset):
|
||||
"""
|
||||
random_idx = random.randint(0, len(dataset) - 1)
|
||||
sample = dataset[random_idx]
|
||||
|
||||
|
||||
# Extract audio data
|
||||
audio_data = sample["audio"]
|
||||
audio_array = audio_data["array"]
|
||||
sample_rate = audio_data["sampling_rate"]
|
||||
|
||||
|
||||
# Convert audio to 16kHz if needed
|
||||
if sample_rate != 16000:
|
||||
num_samples = int(len(audio_array) * (16000 / sample_rate))
|
||||
@@ -141,6 +138,7 @@ def get_random_prompt_from_dataset(dataset):
|
||||
prompt_text = prompt_text.replace(" ", "")
|
||||
return prompt_text, prompt_speech_16k
|
||||
|
||||
|
||||
class _Token2Wav_ASR:
|
||||
"""Wraps a single OmniSenseVoiceSmall model instance for Triton."""
|
||||
|
||||
@@ -163,6 +161,7 @@ class _Token2Wav_ASR:
|
||||
self.codec_decoder = CosyVoice2(
|
||||
"/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
|
||||
)
|
||||
|
||||
@batch
|
||||
def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
|
||||
"""
|
||||
@@ -236,7 +235,6 @@ class _Token2Wav_ASR:
|
||||
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
|
||||
rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
|
||||
|
||||
|
||||
return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user