mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
194 lines
9.3 KiB
Python
194 lines
9.3 KiB
Python
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions
|
|
# are met:
|
|
# * Redistributions of source code must retain the above copyright
|
|
# notice, this list of conditions and the following disclaimer.
|
|
# * Redistributions in binary form must reproduce the above copyright
|
|
# notice, this list of conditions and the following disclaimer in the
|
|
# documentation and/or other materials provided with the distribution.
|
|
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
|
# contributors may be used to endorse or promote products derived
|
|
# from this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
|
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
|
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
|
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
|
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
|
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
|
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
|
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
import json
|
|
import os
|
|
|
|
import logging
|
|
from typing import List, Dict
|
|
|
|
import torch
|
|
from torch.utils.dlpack import to_dlpack
|
|
from torch.nn import functional as F
|
|
|
|
import triton_python_backend_utils as pb_utils
|
|
|
|
from hyperpyyaml import load_hyperpyyaml
|
|
from cosyvoice.utils.common import fade_in_out
|
|
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
|
|
from cosyvoice.utils.common import TrtContextWrapper
|
|
from collections import defaultdict
|
|
import numpy as np
|
|
from .token2wav_dit import CosyVoice2_Token2Wav
|
|
import hashlib
|
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
logger = logging.getLogger(__name__)
|
|
|
|
ORIGINAL_VOCAB_SIZE = 151663
|
|
torch.set_num_threads(1)
|
|
|
|
def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
|
|
"""
|
|
Generates a unique ID for a torch.Tensor.
|
|
Tensors with the same elements and properties will have the same ID.
|
|
"""
|
|
# Convert tensor to a byte string
|
|
tensor_bytes = tensor.numpy().tobytes()
|
|
|
|
# Create a SHA-256 hash of the byte string
|
|
hasher = hashlib.sha256()
|
|
hasher.update(tensor_bytes)
|
|
|
|
return hasher.hexdigest()
|
|
|
|
class TritonPythonModel:
|
|
"""Triton Python model for vocoder.
|
|
|
|
This model takes global and semantic tokens as input and generates audio waveforms
|
|
using the BiCodec vocoder.
|
|
"""
|
|
|
|
def initialize(self, args):
|
|
"""Initialize the model.
|
|
|
|
Args:
|
|
args: Dictionary containing model configuration
|
|
"""
|
|
# Parse model parameters
|
|
parameters = json.loads(args['model_config'])['parameters']
|
|
model_params = {key: value["string_value"] for key, value in parameters.items()}
|
|
model_dir = model_params["model_dir"]
|
|
|
|
# Initialize device and vocoder
|
|
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
|
|
|
|
# FIXME: device id settings
|
|
self.token2wav_model = CosyVoice2_Token2Wav(
|
|
model_dir, enable_trt=True, streaming=True
|
|
)
|
|
logger.info("Token2Wav initialized successfully")
|
|
|
|
def execute(self, requests):
|
|
"""Execute inference on the batched requests.
|
|
|
|
Args:
|
|
requests: List of inference requests
|
|
|
|
Returns:
|
|
List of inference responses containing generated waveforms
|
|
"""
|
|
responses = []
|
|
for request in requests:
|
|
request_id = request.request_id()
|
|
|
|
# Get inputs
|
|
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens")
|
|
target_speech_tokens = torch.utils.dlpack.from_dlpack(target_speech_tokens_tensor.to_dlpack())
|
|
target_speech_tokens = target_speech_tokens.squeeze().tolist()
|
|
|
|
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
|
|
wav_array = pb_utils.get_input_tensor_by_name(request, "reference_wav").as_numpy()
|
|
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len").as_numpy().item()
|
|
wav = torch.from_numpy(wav_array)[:, :wav_len].squeeze(0)
|
|
spk_id = get_spk_id_from_prompt_audio(wav)
|
|
|
|
# Handle cache
|
|
conformer_cnn_cache = pb_utils.get_input_tensor_by_name(request, "conformer_cnn_cache")
|
|
if conformer_cnn_cache is not None:
|
|
self.token2wav_model.streaming_flow_cache[request_id]['conformer_cnn_cache'] = torch.utils.dlpack.from_dlpack(conformer_cnn_cache.to_dlpack())
|
|
|
|
conformer_att_cache_np = pb_utils.get_input_tensor_by_name(request, "conformer_att_cache")
|
|
self.token2wav_model.streaming_flow_cache[request_id]['conformer_att_cache'] = torch.utils.dlpack.from_dlpack(conformer_att_cache_np.to_dlpack()).transpose(0,1)
|
|
|
|
estimator_cnn_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_cnn_cache")
|
|
self.token2wav_model.streaming_flow_cache[request_id]['estimator_cnn_cache'] = torch.utils.dlpack.from_dlpack(estimator_cnn_cache_np.to_dlpack()).squeeze(0)
|
|
|
|
estimator_att_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_att_cache")
|
|
self.token2wav_model.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.utils.dlpack.from_dlpack(estimator_att_cache_np.to_dlpack()).squeeze(0)
|
|
|
|
mel_np = pb_utils.get_input_tensor_by_name(request, "mel")
|
|
self.token2wav_model.streaming_flow_cache[request_id]['mel'] = torch.utils.dlpack.from_dlpack(mel_np.to_dlpack())
|
|
|
|
source_np = pb_utils.get_input_tensor_by_name(request, "source")
|
|
self.token2wav_model.hift_cache_dict[request_id]['source'] = torch.utils.dlpack.from_dlpack(source_np.to_dlpack())
|
|
|
|
speech_np = pb_utils.get_input_tensor_by_name(request, "speech")
|
|
self.token2wav_model.hift_cache_dict[request_id]['speech'] = torch.utils.dlpack.from_dlpack(speech_np.to_dlpack())
|
|
|
|
# Forward pass
|
|
audio_hat = self.token2wav_model.forward_streaming(
|
|
target_speech_tokens,
|
|
finalize,
|
|
request_id=request_id,
|
|
speaker_id=f"{spk_id}",
|
|
prompt_audio=wav,
|
|
prompt_audio_sample_rate=16000
|
|
)
|
|
|
|
# Prepare outputs
|
|
outputs = []
|
|
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
|
|
outputs.append(wav_tensor)
|
|
|
|
if request_id in self.token2wav_model.streaming_flow_cache:
|
|
cache = self.token2wav_model.streaming_flow_cache[request_id]
|
|
hifigan_cache = self.token2wav_model.hift_cache_dict[request_id]
|
|
conformer_cnn_cache = cache['conformer_cnn_cache']
|
|
conformer_att_cache = cache['conformer_att_cache'].transpose(0,1)
|
|
estimator_cnn_cache = cache['estimator_cnn_cache'].unsqueeze(0)
|
|
estimator_att_cache = cache['estimator_att_cache'].unsqueeze(0)
|
|
mel = hifigan_cache['mel']
|
|
source = hifigan_cache['source']
|
|
speech = hifigan_cache['speech']
|
|
|
|
outputs.extend([
|
|
pb_utils.Tensor.from_dlpack("conformer_cnn_cache", to_dlpack(conformer_cnn_cache)),
|
|
pb_utils.Tensor.from_dlpack("conformer_att_cache", to_dlpack(conformer_att_cache)),
|
|
pb_utils.Tensor.from_dlpack("estimator_cnn_cache", to_dlpack(estimator_cnn_cache)),
|
|
pb_utils.Tensor.from_dlpack("estimator_att_cache", to_dlpack(estimator_att_cache)),
|
|
pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)),
|
|
pb_utils.Tensor.from_dlpack("source", to_dlpack(source)),
|
|
pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)),
|
|
])
|
|
else:
|
|
outputs.extend([pb_utils.Tensor("conformer_cnn_cache", np.array([], dtype=np.float16)),
|
|
pb_utils.Tensor("conformer_att_cache", np.array([], dtype=np.float16)),
|
|
pb_utils.Tensor("estimator_cnn_cache", np.array([], dtype=np.float16)),
|
|
pb_utils.Tensor("estimator_att_cache", np.array([], dtype=np.float16)),
|
|
pb_utils.Tensor("mel", np.array([], dtype=np.float32)),
|
|
pb_utils.Tensor("source", np.array([], dtype=np.float32)),
|
|
pb_utils.Tensor("speech", np.array([], dtype=np.float32)),
|
|
])
|
|
|
|
inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
|
|
responses.append(inference_response)
|
|
return responses
|
|
|
|
def finalize(self):
|
|
self.logger.log_info("Finalizing Token2WavDiT model")
|