mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
fix lint
This commit is contained in:
@@ -48,9 +48,11 @@ import hashlib
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ORIGINAL_VOCAB_SIZE = 151663
|
||||
torch.set_num_threads(1)
|
||||
|
||||
|
||||
def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
|
||||
"""
|
||||
Generates a unique ID for a torch.Tensor.
|
||||
@@ -65,6 +67,7 @@ def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
|
||||
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
class TritonPythonModel:
|
||||
"""Triton Python model for vocoder.
|
||||
|
||||
@@ -114,7 +117,6 @@ class TritonPythonModel:
|
||||
|
||||
request_id = request.request_id()
|
||||
|
||||
|
||||
wav_array = pb_utils.get_input_tensor_by_name(
|
||||
request, "reference_wav").as_numpy()
|
||||
wav_len = pb_utils.get_input_tensor_by_name(
|
||||
@@ -125,7 +127,10 @@ class TritonPythonModel:
|
||||
|
||||
spk_id = get_spk_id_from_prompt_audio(wav)
|
||||
|
||||
audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
|
||||
audio_hat = self.token2wav_model.forward_streaming(
|
||||
target_speech_tokens, finalize, request_id=request_id,
|
||||
speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000
|
||||
)
|
||||
|
||||
outputs = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user