update model inference

This commit is contained in:
lyuxiang.lx
2024-07-24 19:18:09 +08:00
parent a13411c561
commit 02f941d348
5 changed files with 85 additions and 64 deletions

View File

@@ -24,14 +24,8 @@ import torchaudio
import random
import librosa
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav, speed_change
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
from cosyvoice.utils.file_utils import load_wav, speed_change, logging
def generate_seed():
seed = random.randint(1, 100000000)
@@ -63,10 +57,11 @@ instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成
'3s极速复刻': '1. 选择prompt音频文件或录入prompt音频注意不超过30s若同时提供优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
'跨语种复刻': '1. 选择prompt音频文件或录入prompt音频注意不超过30s若同时提供优先选择prompt音频文件\n2. 点击生成音频按钮',
'自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
stream_mode_list = [('', False), ('', True)]
def change_instruction(mode_checkbox_group):
return instruct_dict[mode_checkbox_group]
def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, speed_factor):
def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor):
if prompt_wav_upload is not None:
prompt_wav = prompt_wav_upload
elif prompt_wav_record is not None:
@@ -117,32 +112,25 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
if mode_checkbox_group == '预训练音色':
logging.info('get sft inference request')
set_all_random_seed(seed)
output = cosyvoice.inference_sft(tts_text, sft_dropdown)
for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream):
yield (target_sr, i['tts_speech'].numpy().flatten())
elif mode_checkbox_group == '3s极速复刻':
logging.info('get zero_shot inference request')
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
set_all_random_seed(seed)
output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream):
yield (target_sr, i['tts_speech'].numpy().flatten())
elif mode_checkbox_group == '跨语种复刻':
logging.info('get cross_lingual inference request')
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
set_all_random_seed(seed)
output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream):
yield (target_sr, i['tts_speech'].numpy().flatten())
else:
logging.info('get instruct inference request')
set_all_random_seed(seed)
output = cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text)
if speed_factor != 1.0:
try:
audio_data, sample_rate = speed_change(output["tts_speech"], target_sr, str(speed_factor))
audio_data = audio_data.numpy().flatten()
except Exception as e:
print(f"Failed to change speed of audio: \n{e}")
else:
audio_data = output['tts_speech'].numpy().flatten()
return (target_sr, audio_data)
for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream):
yield (target_sr, i['tts_speech'].numpy().flatten())
def main():
with gr.Blocks() as demo:
@@ -155,6 +143,7 @@ def main():
mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25)
stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1])
with gr.Column(scale=0.25):
seed_button = gr.Button(value="\U0001F3B2")
seed = gr.Number(value=0, label="随机推理种子")
@@ -167,11 +156,11 @@ def main():
generate_button = gr.Button("生成音频")
audio_output = gr.Audio(label="合成音频")
audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=True)
seed_button.click(generate_seed, inputs=[], outputs=seed)
generate_button.click(generate_audio,
inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, speed_factor],
inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor],
outputs=[audio_output])
mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
demo.queue(max_size=4, default_concurrency_limit=2)