use automodel

This commit is contained in:
lyuxiang.lx
2025-12-09 15:15:05 +00:00
parent 56d9876037
commit 0c65d3c7ab
8 changed files with 56 additions and 88 deletions

View File

@@ -22,8 +22,8 @@ import random
import librosa
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav, logging
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.common import set_all_random_seed
inference_mode_list = ['预训练音色', '3s极速复刻', '跨语种复刻', '自然语言控制']
@@ -42,23 +42,9 @@ def generate_seed():
"value": seed
}
def postprocess(speech, top_db=60, hop_length=220, win_length=440):
    """Clean up a reference speech clip before it is handed to the TTS model.

    Steps:
      1. Trim leading/trailing silence with ``librosa.effects.trim``
         (silence threshold ``top_db`` dB, window ``win_length``,
         hop ``hop_length``).
      2. If the peak absolute amplitude exceeds the module-level
         ``max_val``, rescale so the new peak equals ``max_val``.
      3. Append 0.2 s of silence at ``cosyvoice.sample_rate``.

    NOTE(review): relies on module globals ``max_val``, ``torch`` and
    ``cosyvoice``; ``speech`` is presumably a ``(1, n)`` torch tensor —
    confirm against the callers.
    """
    trimmed, _ = librosa.effects.trim(
        speech,
        top_db=top_db,
        frame_length=win_length,
        hop_length=hop_length,
    )
    peak = trimmed.abs().max()
    if peak > max_val:
        trimmed = trimmed / peak * max_val
    silence_tail = torch.zeros(1, int(cosyvoice.sample_rate * 0.2))
    return torch.concat([trimmed, silence_tail], dim=1)
def change_instruction(mode_checkbox_group):
    """Look up the instruction/help text for the selected inference mode.

    NOTE(review): ``instruct_dict`` is a module-level mapping defined
    elsewhere in the file; an unknown mode raises ``KeyError`` (unchanged
    behavior).
    """
    instruction_text = instruct_dict[mode_checkbox_group]
    return instruction_text
def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
seed, stream, speed):
if prompt_wav_upload is not None:
@@ -118,15 +104,13 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
elif mode_checkbox_group == '3s极速复刻':
logging.info('get zero_shot inference request')
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
set_all_random_seed(seed)
for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_wav, stream=stream, speed=speed):
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
elif mode_checkbox_group == '跨语种复刻':
logging.info('get cross_lingual inference request')
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
set_all_random_seed(seed)
for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
for i in cosyvoice.inference_cross_lingual(tts_text, prompt_wav, stream=stream, speed=speed):
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
else:
logging.info('get instruct inference request')
@@ -181,16 +165,10 @@ if __name__ == '__main__':
default=8000)
parser.add_argument('--model_dir',
type=str,
default='pretrained_models/CosyVoice2-0.5B',
default='pretrained_models/CosyVoice3-0.5B',
help='local path or modelscope repo id')
args = parser.parse_args()
try:
cosyvoice = CosyVoice(args.model_dir)
except Exception:
try:
cosyvoice = CosyVoice2(args.model_dir)
except Exception:
raise TypeError('no valid model_type!')
model = AutoModel(model_dir=args.model_dir)
sft_spk = cosyvoice.list_available_spks()
if len(sft_spk) == 0: