mirror of https://github.com/FunAudioLLM/CosyVoice.git

Commit: clean code
@@ -692,7 +692,7 @@ async def main():
             model_name=args.model_name,
             audio_save_dir=args.log_dir,
             padding_duration=10,
-            save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
+            save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
             chunk_overlap_duration=args.chunk_overlap_duration,
         )
     )
@@ -162,8 +162,8 @@ if __name__ == "__main__":
     result = rsp.json()
     audio = result["outputs"][0]["data"]
     audio = np.array(audio, dtype=np.float32)
-    if args.model_name == "cosyvoice2":
-        sample_rate = 24000
-    else:
+    if args.model_name == "spark_tts":
         sample_rate = 16000
+    else:
+        sample_rate = 24000
     sf.write(args.output_audio, audio, sample_rate, "PCM_16")
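Note: both client scripts previously picked the output sample rate by checking the wrong model names (only `f5_tts` mapped to 24 kHz, and `cosyvoice2` fell through to 16 kHz). The commit inverts the condition so that `spark_tts` is the only 16 kHz model and everything else, including `cosyvoice2` and `f5_tts`, saves at 24 kHz. A minimal sketch of the corrected rule (the function name is illustrative, not from the diff):

```python
def pick_sample_rate(model_name: str) -> int:
    # Spark-TTS outputs 16 kHz audio; CosyVoice 2 and F5-TTS output 24 kHz.
    return 16000 if model_name == "spark_tts" else 24000

assert pick_sample_rate("spark_tts") == 16000
assert pick_sample_rate("cosyvoice2") == 24000
```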
@@ -33,6 +33,7 @@ import os
 import numpy as np
 import s3tokenizer
 
+ORIGINAL_VOCAB_SIZE = 151663
 
 class TritonPythonModel:
     """Triton Python model for audio tokenization.
@@ -81,7 +82,7 @@ class TritonPythonModel:
 
         mels, mels_lens = s3tokenizer.padding(mels)
         codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
-        codes = codes.clone() + 151663
+        codes = codes.clone() + ORIGINAL_VOCAB_SIZE
 
         responses = []
         for i in range(len(requests)):
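Note: 151663 is the size of the text LLM's original vocabulary; speech codec tokens are appended after it, so the audio tokenizer shifts raw codec indices up by `ORIGINAL_VOCAB_SIZE` before they enter the LLM, and the token2wav model shifts them back down (see the later hunk). Naming the constant replaces two copies of the same magic number. A sketch of the round trip (helper names are illustrative, not from the diff):

```python
import torch

ORIGINAL_VOCAB_SIZE = 151663  # size of the base text vocabulary, per the diff

def to_llm_ids(codec_ids: torch.Tensor) -> torch.Tensor:
    # Shift raw codec indices into the LLM's extended vocabulary range.
    return codec_ids + ORIGINAL_VOCAB_SIZE

def to_codec_ids(llm_ids: torch.Tensor) -> torch.Tensor:
    # Undo the shift before the flow/vocoder stage.
    return llm_ids - ORIGINAL_VOCAB_SIZE

codes = torch.tensor([[0, 42, 1024]])
assert torch.equal(to_codec_ids(to_llm_ids(codes)), codes)
```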
@@ -199,8 +199,6 @@ class TritonPythonModel:
         Returns:
             Generated waveform tensor
         """
-        print(prompt_speech_tokens.shape, prompt_speech_feat.shape, prompt_spk_embedding.shape, target_speech_tokens.shape)
-        # Convert tensors to Triton format
         prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
         prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
         prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
@@ -228,9 +226,7 @@ class TritonPythonModel:
         prompt = self.prompt_template.format(input_text=total_text)
         input_ids = self.tokenizer.encode(prompt)
         input_ids = torch.tensor([input_ids], dtype=torch.int32)
-        print(input_ids.shape, "before cat")
         input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
-        print(input_ids.shape, "after cat", prompt_speech_tokens.shape)
         return input_ids
 
     def _extract_spk_embedding(self, speech):
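Note: this helper tokenizes the formatted text prompt and concatenates the (already shifted) prompt speech tokens along the sequence dimension; the shape-debugging prints around the `torch.cat` are dropped. A minimal sketch of the shape bookkeeping, with a stand-in for the real tokenizer and template:

```python
import torch

def build_input_ids(text_ids: list, prompt_speech_tokens: torch.Tensor) -> torch.Tensor:
    # text_ids: encoded text prompt -> (1, T_text); prompt_speech_tokens: (1, T_speech)
    input_ids = torch.tensor([text_ids], dtype=torch.int32)
    return torch.cat([input_ids, prompt_speech_tokens.to(torch.int32)], dim=1)

ids = build_input_ids([1, 2, 3], torch.tensor([[151663, 151664]]))
assert ids.shape == (1, 5)  # (1, T_text + T_speech)
```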
@@ -271,23 +267,15 @@ class TritonPythonModel:
         prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
         prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
 
-        # TODO: FIX ME
         wav_tensor = wav.as_numpy()
-        print(wav_tensor.shape, "wav_tensor")
         wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
-        print(wav_tensor.shape, "wav_tensor after")
         prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
         speech_feat = self._extract_speech_feat(prompt_speech_resample)
-        print(speech_feat.shape, "speech_feat")
-        print(prompt_speech_tokens.shape, "prompt_speech_tokens here")
         token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
         prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
         prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
-        print(prompt_speech_tokens.shape, "prompt_speech_tokens after")
-        print(speech_feat.shape, "speech_feat after")
-        print(token_len, "token_len")
 
-        # Extract text inputs
         reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
         reference_text = reference_text[0][0].decode('utf-8')
 
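Note: the reference wav is truncated to its valid length, resampled from 16 kHz to 24 kHz, and turned into mel features; since the features run at two frames per speech token, the prompt features and prompt tokens are then trimmed to a consistent length (`token_len = min(feat_frames // 2, num_tokens)`). The commit removes the scaffolding prints and the stale `# TODO: FIX ME`. A sketch of the 2:1 alignment with toy shapes:

```python
import torch

def align_prompt(speech_feat: torch.Tensor, speech_tokens: torch.Tensor):
    # speech_feat: (1, T_feat, D) mel frames, two per speech token
    # speech_tokens: (1, T_tok) codec ids
    token_len = min(speech_feat.shape[1] // 2, speech_tokens.shape[-1])
    return speech_feat[:, :2 * token_len], speech_tokens[:, :token_len]

feat, toks = align_prompt(torch.zeros(1, 101, 80), torch.zeros(1, 48, dtype=torch.long))
assert feat.shape[1] == 96 and toks.shape[1] == 48
```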
@@ -38,13 +38,11 @@ import triton_python_backend_utils as pb_utils
 from hyperpyyaml import load_hyperpyyaml
 from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
 from cosyvoice.utils.common import TrtContextWrapper
-#import sys
-#sys.path.append("/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice/third_party/Matcha-TTS")
 
-# Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
+ORIGINAL_VOCAB_SIZE = 151663
 
 class CosyVoice2:
 
@@ -162,8 +160,9 @@ class TritonPythonModel:
         prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
         prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
 
-        prompt_speech_tokens = prompt_speech_tokens - 151663
-        target_speech_tokens = target_speech_tokens - 151663
+        # shift the speech tokens according to the original vocab size
+        prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
+        target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
 
         tts_mel, _ = self.token2wav_model.model.flow.inference(
             token=target_speech_tokens,
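Note: this is the decode-side counterpart of the shift applied in the audio tokenizer; replacing the literal 151663 with `ORIGINAL_VOCAB_SIZE` keeps both sides in sync, and the new comment documents why the subtraction happens. A guard one might add after the shift (purely illustrative; `SPEECH_VOCAB_SIZE` is a hypothetical name and value, not defined in the diff):

```python
SPEECH_VOCAB_SIZE = 6561  # placeholder for the codec codebook size; not from the diff

def check_codec_range(tokens):
    # After subtracting ORIGINAL_VOCAB_SIZE, ids must index into the codec codebook.
    assert tokens.min() >= 0 and tokens.max() < SPEECH_VOCAB_SIZE, \
        "speech tokens were not shifted back by ORIGINAL_VOCAB_SIZE"
```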
@@ -1,8 +1,4 @@
-# huggingface-cli download --local-dir cosyvoice2_llm yuekai/cosyvoice2_llm
-# modelscope download --model iic/CosyVoice2-0.5B --local_dir ./CosyVoice2-0.5B/
-# git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
-# cd CosyVoice
-# git submodule update --init --recursive
 export CUDA_VISIBLE_DEVICES=0
 export PYTHONPATH=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice:$PYTHONPATH
 export PYTHONPATH=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice/third_party/Matcha-TTS:$PYTHONPATH
@@ -12,11 +8,21 @@ stop_stage=$2
 huggingface_model_local_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/cosyvoice2_llm
 model_scope_model_local_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice2-0.5B
 trt_dtype=bfloat16
-trt_dtype=float16
 trt_weights_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/trt_weights_${trt_dtype}
 trt_engines_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/trt_engines_${trt_dtype}
 
 model_repo=./model_repo_cosyvoice2
 
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+    echo " "
+    huggingface-cli download --local-dir cosyvoice2_llm yuekai/cosyvoice2_llm
+    modelscope download --model iic/CosyVoice2-0.5B --local_dir ./CosyVoice2-0.5B/
+    git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
+    cd CosyVoice
+    git submodule update --init --recursive
+fi
+
 
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
     echo "Converting checkpoint to TensorRT weights"
     python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \
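Note: the download and repository-setup commands that used to sit as comments at the top of run.sh are now a proper stage 0, guarded like every other stage by `[ $stage -le N ] && [ $stop_stage -ge N ]`. Assuming `stage` and `stop_stage` come from the first two positional arguments (the hunk context shows `stop_stage=$2`), `bash run.sh 0 1` would download the models and then convert the checkpoint, while `bash run.sh 1 1` skips the downloads. The duplicate `trt_dtype=float16` override is also removed, so the script consistently builds ${trt_dtype} (bfloat16) engines.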