mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
fix lint
This commit is contained in:
@@ -44,6 +44,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
ORIGINAL_VOCAB_SIZE = 151663
|
||||
|
||||
|
||||
class CosyVoice2:
|
||||
|
||||
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
|
||||
@@ -66,6 +67,7 @@ class CosyVoice2:
|
||||
trt_concurrent,
|
||||
self.fp16)
|
||||
|
||||
|
||||
class CosyVoice2Model:
|
||||
|
||||
def __init__(self,
|
||||
@@ -109,16 +111,17 @@ class CosyVoice2Model:
|
||||
input_names = ["x", "mask", "mu", "cond"]
|
||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||
|
||||
|
||||
class TritonPythonModel:
|
||||
"""Triton Python model for vocoder.
|
||||
|
||||
|
||||
This model takes global and semantic tokens as input and generates audio waveforms
|
||||
using the BiCodec vocoder.
|
||||
"""
|
||||
|
||||
def initialize(self, args):
|
||||
"""Initialize the model.
|
||||
|
||||
|
||||
Args:
|
||||
args: Dictionary containing model configuration
|
||||
"""
|
||||
@@ -126,24 +129,23 @@ class TritonPythonModel:
|
||||
parameters = json.loads(args['model_config'])['parameters']
|
||||
model_params = {key: value["string_value"] for key, value in parameters.items()}
|
||||
model_dir = model_params["model_dir"]
|
||||
|
||||
|
||||
# Initialize device and vocoder
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
|
||||
|
||||
|
||||
self.token2wav_model = CosyVoice2(
|
||||
model_dir, load_jit=True, load_trt=True, fp16=True
|
||||
)
|
||||
|
||||
logger.info("Token2Wav initialized successfully")
|
||||
|
||||
|
||||
def execute(self, requests):
|
||||
"""Execute inference on the batched requests.
|
||||
|
||||
|
||||
Args:
|
||||
requests: List of inference requests
|
||||
|
||||
|
||||
Returns:
|
||||
List of inference responses containing generated waveforms
|
||||
"""
|
||||
@@ -163,7 +165,7 @@ class TritonPythonModel:
|
||||
# shift the speech tokens according to the original vocab size
|
||||
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
|
||||
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
|
||||
|
||||
|
||||
tts_mel, _ = self.token2wav_model.model.flow.inference(
|
||||
token=target_speech_tokens,
|
||||
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
|
||||
@@ -189,9 +191,5 @@ class TritonPythonModel:
|
||||
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
|
||||
responses.append(inference_response)
|
||||
|
||||
|
||||
return responses
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user