mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 01:49:25 +08:00
Merge pull request #327 from FunAudioLLM/inference_streaming
Inference streaming
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -43,6 +43,8 @@ compile_commands.json
|
|||||||
|
|
||||||
# train/inference files
|
# train/inference files
|
||||||
*.wav
|
*.wav
|
||||||
|
*.m4a
|
||||||
|
*.aac
|
||||||
*.pt
|
*.pt
|
||||||
pretrained_models/*
|
pretrained_models/*
|
||||||
*_pb2_grpc.py
|
*_pb2_grpc.py
|
||||||
|
|||||||
21
README.md
21
README.md
@@ -116,23 +116,24 @@ import torchaudio
|
|||||||
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
|
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
|
||||||
# sft usage
|
# sft usage
|
||||||
print(cosyvoice.list_avaliable_spks())
|
print(cosyvoice.list_avaliable_spks())
|
||||||
output = cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女')
|
# change stream=True for chunk stream inference
|
||||||
torchaudio.save('sft.wav', output['tts_speech'], 22050)
|
for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
|
||||||
|
torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], 22050)
|
||||||
|
|
||||||
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
|
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
|
||||||
# zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
|
# zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
|
||||||
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
|
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
|
||||||
output = cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k)
|
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
||||||
torchaudio.save('zero_shot.wav', output['tts_speech'], 22050)
|
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], 22050)
|
||||||
# cross_lingual usage
|
# cross_lingual usage
|
||||||
prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
|
prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
|
||||||
output = cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k)
|
for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
|
||||||
torchaudio.save('cross_lingual.wav', output['tts_speech'], 22050)
|
torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], 22050)
|
||||||
|
|
||||||
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
|
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
|
||||||
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
|
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
|
||||||
output = cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
|
for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
|
||||||
torchaudio.save('instruct.wav', output['tts_speech'], 22050)
|
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], 22050)
|
||||||
```
|
```
|
||||||
|
|
||||||
**Start web demo**
|
**Start web demo**
|
||||||
@@ -163,10 +164,10 @@ docker build -t cosyvoice:v1.0 .
|
|||||||
# change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
|
# change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
|
||||||
# for grpc usage
|
# for grpc usage
|
||||||
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
|
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
|
||||||
python3 grpc/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
||||||
# for fastapi usage
|
# for fastapi usage
|
||||||
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && MODEL_DIR=iic/CosyVoice-300M fastapi dev --port 50000 server.py && sleep infinity"
|
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && MODEL_DIR=iic/CosyVoice-300M fastapi dev --port 50000 server.py && sleep infinity"
|
||||||
python3 fastapi/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
||||||
```
|
```
|
||||||
|
|
||||||
## Discussion & Communication
|
## Discussion & Communication
|
||||||
|
|||||||
64
cosyvoice/bin/export_jit.py
Normal file
64
cosyvoice/bin/export_jit.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||||
|
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||||
|
import torch
|
||||||
|
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(description='export your model for deployment')
|
||||||
|
parser.add_argument('--model_dir',
|
||||||
|
type=str,
|
||||||
|
default='pretrained_models/CosyVoice-300M',
|
||||||
|
help='local path')
|
||||||
|
args = parser.parse_args()
|
||||||
|
print(args)
|
||||||
|
return args
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
logging.basicConfig(level=logging.DEBUG,
|
||||||
|
format='%(asctime)s %(levelname)s %(message)s')
|
||||||
|
|
||||||
|
torch._C._jit_set_fusion_strategy([('STATIC', 1)])
|
||||||
|
torch._C._jit_set_profiling_mode(False)
|
||||||
|
torch._C._jit_set_profiling_executor(False)
|
||||||
|
|
||||||
|
cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
|
||||||
|
|
||||||
|
# 1. export llm text_encoder
|
||||||
|
llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
|
||||||
|
script = torch.jit.script(llm_text_encoder)
|
||||||
|
script = torch.jit.freeze(script)
|
||||||
|
script = torch.jit.optimize_for_inference(script)
|
||||||
|
script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
|
||||||
|
|
||||||
|
# 2. export llm llm
|
||||||
|
llm_llm = cosyvoice.model.llm.llm.half()
|
||||||
|
script = torch.jit.script(llm_llm)
|
||||||
|
script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
|
||||||
|
script = torch.jit.optimize_for_inference(script)
|
||||||
|
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
8
cosyvoice/bin/export_trt.py
Normal file
8
cosyvoice/bin/export_trt.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# TODO 跟export_jit一样的逻辑,完成flow部分的estimator的onnx导出。
|
||||||
|
# tensorrt的安装方式,再这里写一下步骤提示如下,如果没有安装,那么不要执行这个脚本,提示用户先安装,不给选择
|
||||||
|
try:
|
||||||
|
import tensorrt
|
||||||
|
except ImportError:
|
||||||
|
print('step1, 下载\n step2. 解压,安装whl,')
|
||||||
|
# 安装命令里tensosrt的根目录用环境变量导入,比如os.environ['tensorrt_root_dir']/bin/exetrace,然后python里subprocess里执行导出命令
|
||||||
|
# 后面我会在run.sh里写好执行命令 tensorrt_root_dir=xxxx python cosyvoice/bin/export_trt.py --model_dir xxx
|
||||||
@@ -100,10 +100,13 @@ def main():
|
|||||||
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
||||||
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
||||||
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
|
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
|
||||||
model_output = model.inference(**model_input)
|
tts_speeches = []
|
||||||
|
for model_output in model.inference(**model_input):
|
||||||
|
tts_speeches.append(model_output['tts_speech'])
|
||||||
|
tts_speeches = torch.concat(tts_speeches, dim=1)
|
||||||
tts_key = '{}_{}'.format(utts[0], tts_index[0])
|
tts_key = '{}_{}'.format(utts[0], tts_index[0])
|
||||||
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
|
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
|
||||||
torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
|
torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
|
||||||
f.write('{} {}\n'.format(tts_key, tts_fn))
|
f.write('{} {}\n'.format(tts_key, tts_fn))
|
||||||
f.flush()
|
f.flush()
|
||||||
f.close()
|
f.close()
|
||||||
|
|||||||
@@ -12,15 +12,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
import torch
|
import time
|
||||||
from hyperpyyaml import load_hyperpyyaml
|
from hyperpyyaml import load_hyperpyyaml
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
||||||
from cosyvoice.cli.model import CosyVoiceModel
|
from cosyvoice.cli.model import CosyVoiceModel
|
||||||
|
from cosyvoice.utils.file_utils import logging
|
||||||
|
|
||||||
class CosyVoice:
|
class CosyVoice:
|
||||||
|
|
||||||
def __init__(self, model_dir):
|
def __init__(self, model_dir, load_jit=True):
|
||||||
instruct = True if '-Instruct' in model_dir else False
|
instruct = True if '-Instruct' in model_dir else False
|
||||||
self.model_dir = model_dir
|
self.model_dir = model_dir
|
||||||
if not os.path.exists(model_dir):
|
if not os.path.exists(model_dir):
|
||||||
@@ -38,46 +39,61 @@ class CosyVoice:
|
|||||||
self.model.load('{}/llm.pt'.format(model_dir),
|
self.model.load('{}/llm.pt'.format(model_dir),
|
||||||
'{}/flow.pt'.format(model_dir),
|
'{}/flow.pt'.format(model_dir),
|
||||||
'{}/hift.pt'.format(model_dir))
|
'{}/hift.pt'.format(model_dir))
|
||||||
|
if load_jit:
|
||||||
|
self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
|
||||||
|
'{}/llm.llm.fp16.zip'.format(model_dir))
|
||||||
del configs
|
del configs
|
||||||
|
|
||||||
def list_avaliable_spks(self):
|
def list_avaliable_spks(self):
|
||||||
spks = list(self.frontend.spk2info.keys())
|
spks = list(self.frontend.spk2info.keys())
|
||||||
return spks
|
return spks
|
||||||
|
|
||||||
def inference_sft(self, tts_text, spk_id):
|
def inference_sft(self, tts_text, spk_id, stream=False):
|
||||||
tts_speeches = []
|
|
||||||
for i in self.frontend.text_normalize(tts_text, split=True):
|
for i in self.frontend.text_normalize(tts_text, split=True):
|
||||||
model_input = self.frontend.frontend_sft(i, spk_id)
|
model_input = self.frontend.frontend_sft(i, spk_id)
|
||||||
model_output = self.model.inference(**model_input)
|
start_time = time.time()
|
||||||
tts_speeches.append(model_output['tts_speech'])
|
logging.info('synthesis text {}'.format(i))
|
||||||
return {'tts_speech': torch.concat(tts_speeches, dim=1)}
|
for model_output in self.model.inference(**model_input, stream=stream):
|
||||||
|
speech_len = model_output['tts_speech'].shape[1] / 22050
|
||||||
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
|
yield model_output
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
|
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
|
||||||
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
|
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
|
||||||
tts_speeches = []
|
|
||||||
for i in self.frontend.text_normalize(tts_text, split=True):
|
for i in self.frontend.text_normalize(tts_text, split=True):
|
||||||
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
|
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
|
||||||
model_output = self.model.inference(**model_input)
|
start_time = time.time()
|
||||||
tts_speeches.append(model_output['tts_speech'])
|
logging.info('synthesis text {}'.format(i))
|
||||||
return {'tts_speech': torch.concat(tts_speeches, dim=1)}
|
for model_output in self.model.inference(**model_input, stream=stream):
|
||||||
|
speech_len = model_output['tts_speech'].shape[1] / 22050
|
||||||
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
|
yield model_output
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
def inference_cross_lingual(self, tts_text, prompt_speech_16k):
|
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
|
||||||
if self.frontend.instruct is True:
|
if self.frontend.instruct is True:
|
||||||
raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
|
raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
|
||||||
tts_speeches = []
|
|
||||||
for i in self.frontend.text_normalize(tts_text, split=True):
|
for i in self.frontend.text_normalize(tts_text, split=True):
|
||||||
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
|
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
|
||||||
model_output = self.model.inference(**model_input)
|
start_time = time.time()
|
||||||
tts_speeches.append(model_output['tts_speech'])
|
logging.info('synthesis text {}'.format(i))
|
||||||
return {'tts_speech': torch.concat(tts_speeches, dim=1)}
|
for model_output in self.model.inference(**model_input, stream=stream):
|
||||||
|
speech_len = model_output['tts_speech'].shape[1] / 22050
|
||||||
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
|
yield model_output
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
def inference_instruct(self, tts_text, spk_id, instruct_text):
|
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
|
||||||
if self.frontend.instruct is False:
|
if self.frontend.instruct is False:
|
||||||
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
|
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
|
||||||
instruct_text = self.frontend.text_normalize(instruct_text, split=False)
|
instruct_text = self.frontend.text_normalize(instruct_text, split=False)
|
||||||
tts_speeches = []
|
|
||||||
for i in self.frontend.text_normalize(tts_text, split=True):
|
for i in self.frontend.text_normalize(tts_text, split=True):
|
||||||
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
|
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
|
||||||
model_output = self.model.inference(**model_input)
|
start_time = time.time()
|
||||||
tts_speeches.append(model_output['tts_speech'])
|
logging.info('synthesis text {}'.format(i))
|
||||||
return {'tts_speech': torch.concat(tts_speeches, dim=1)}
|
for model_output in self.model.inference(**model_input, stream=stream):
|
||||||
|
speech_len = model_output['tts_speech'].shape[1] / 22050
|
||||||
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
|
yield model_output
|
||||||
|
start_time = time.time()
|
||||||
|
|||||||
@@ -12,6 +12,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from contextlib import nullcontext
|
||||||
|
import uuid
|
||||||
|
from cosyvoice.utils.common import fade_in_out
|
||||||
|
|
||||||
|
|
||||||
class CosyVoiceModel:
|
class CosyVoiceModel:
|
||||||
|
|
||||||
@@ -23,38 +30,143 @@ class CosyVoiceModel:
|
|||||||
self.llm = llm
|
self.llm = llm
|
||||||
self.flow = flow
|
self.flow = flow
|
||||||
self.hift = hift
|
self.hift = hift
|
||||||
|
self.token_min_hop_len = 100
|
||||||
|
self.token_max_hop_len = 200
|
||||||
|
self.token_overlap_len = 20
|
||||||
|
# mel fade in out
|
||||||
|
self.mel_overlap_len = 34
|
||||||
|
self.mel_window = np.hamming(2 * self.mel_overlap_len)
|
||||||
|
# hift cache
|
||||||
|
self.mel_cache_len = 20
|
||||||
|
self.source_cache_len = int(self.mel_cache_len * 256)
|
||||||
|
# rtf and decoding related
|
||||||
|
self.stream_scale_factor = 1
|
||||||
|
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
||||||
|
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
||||||
|
self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
# dict used to store session related variable
|
||||||
|
self.tts_speech_token_dict = {}
|
||||||
|
self.llm_end_dict = {}
|
||||||
|
self.mel_overlap_dict = {}
|
||||||
|
self.hift_cache_dict = {}
|
||||||
|
|
||||||
def load(self, llm_model, flow_model, hift_model):
|
def load(self, llm_model, flow_model, hift_model):
|
||||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
|
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
|
||||||
self.llm.to(self.device).eval()
|
self.llm.to(self.device).eval()
|
||||||
|
self.llm.half()
|
||||||
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
|
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
|
||||||
self.flow.to(self.device).eval()
|
self.flow.to(self.device).eval()
|
||||||
self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
|
self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
|
||||||
self.hift.to(self.device).eval()
|
self.hift.to(self.device).eval()
|
||||||
|
|
||||||
def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
def load_jit(self, llm_text_encoder_model, llm_llm_model):
|
||||||
prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
|
llm_text_encoder = torch.jit.load(llm_text_encoder_model)
|
||||||
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
|
self.llm.text_encoder = llm_text_encoder
|
||||||
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
|
llm_llm = torch.jit.load(llm_llm_model)
|
||||||
prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
|
self.llm.llm = llm_llm
|
||||||
tts_speech_token = self.llm.inference(text=text.to(self.device),
|
|
||||||
text_len=text_len.to(self.device),
|
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
||||||
prompt_text=prompt_text.to(self.device),
|
with self.llm_context:
|
||||||
prompt_text_len=prompt_text_len.to(self.device),
|
for i in self.llm.inference(text=text.to(self.device),
|
||||||
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
|
prompt_text=prompt_text.to(self.device),
|
||||||
embedding=llm_embedding.to(self.device),
|
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
beam_size=1,
|
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
||||||
sampling=25,
|
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
max_token_text_ratio=30,
|
embedding=llm_embedding.to(self.device).half(),
|
||||||
min_token_text_ratio=3)
|
sampling=25,
|
||||||
tts_mel = self.flow.inference(token=tts_speech_token,
|
max_token_text_ratio=30,
|
||||||
token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
|
min_token_text_ratio=3):
|
||||||
prompt_token=flow_prompt_speech_token.to(self.device),
|
self.tts_speech_token_dict[uuid].append(i)
|
||||||
prompt_token_len=flow_prompt_speech_token_len.to(self.device),
|
self.llm_end_dict[uuid] = True
|
||||||
prompt_feat=prompt_speech_feat.to(self.device),
|
|
||||||
prompt_feat_len=prompt_speech_feat_len.to(self.device),
|
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
|
||||||
embedding=flow_embedding.to(self.device))
|
with self.flow_hift_context:
|
||||||
tts_speech = self.hift.inference(mel=tts_mel).cpu()
|
tts_mel = self.flow.inference(token=token.to(self.device),
|
||||||
torch.cuda.empty_cache()
|
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
return {'tts_speech': tts_speech}
|
prompt_token=prompt_token.to(self.device),
|
||||||
|
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
prompt_feat=prompt_feat.to(self.device),
|
||||||
|
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
embedding=embedding.to(self.device))
|
||||||
|
# mel overlap fade in out
|
||||||
|
if self.mel_overlap_dict[uuid] is not None:
|
||||||
|
tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
|
||||||
|
# append hift cache
|
||||||
|
if self.hift_cache_dict[uuid] is not None:
|
||||||
|
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
||||||
|
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
||||||
|
else:
|
||||||
|
hift_cache_source = torch.zeros(1, 1, 0)
|
||||||
|
# keep overlap mel and hift cache
|
||||||
|
if finalize is False:
|
||||||
|
self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
|
||||||
|
tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
|
||||||
|
tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
|
||||||
|
self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
|
||||||
|
tts_speech = tts_speech[:, :-self.source_cache_len]
|
||||||
|
else:
|
||||||
|
tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
|
||||||
|
return tts_speech
|
||||||
|
|
||||||
|
def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
||||||
|
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
||||||
|
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||||
|
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||||
|
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
|
||||||
|
# this_uuid is used to track variables related to this inference thread
|
||||||
|
this_uuid = str(uuid.uuid1())
|
||||||
|
with self.lock:
|
||||||
|
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
|
||||||
|
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||||
|
p.start()
|
||||||
|
if stream is True:
|
||||||
|
token_hop_len = self.token_min_hop_len
|
||||||
|
while True:
|
||||||
|
time.sleep(0.1)
|
||||||
|
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
||||||
|
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
|
||||||
|
with self.flow_hift_context:
|
||||||
|
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||||
|
prompt_token=flow_prompt_speech_token,
|
||||||
|
prompt_feat=prompt_speech_feat,
|
||||||
|
embedding=flow_embedding,
|
||||||
|
uuid=this_uuid,
|
||||||
|
finalize=False)
|
||||||
|
yield {'tts_speech': this_tts_speech.cpu()}
|
||||||
|
with self.lock:
|
||||||
|
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
||||||
|
# increase token_hop_len for better speech quality
|
||||||
|
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
||||||
|
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
||||||
|
break
|
||||||
|
p.join()
|
||||||
|
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||||
|
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
|
||||||
|
with self.flow_hift_context:
|
||||||
|
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||||
|
prompt_token=flow_prompt_speech_token,
|
||||||
|
prompt_feat=prompt_speech_feat,
|
||||||
|
embedding=flow_embedding,
|
||||||
|
uuid=this_uuid,
|
||||||
|
finalize=True)
|
||||||
|
yield {'tts_speech': this_tts_speech.cpu()}
|
||||||
|
else:
|
||||||
|
# deal with all tokens
|
||||||
|
p.join()
|
||||||
|
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
|
||||||
|
with self.flow_hift_context:
|
||||||
|
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||||
|
prompt_token=flow_prompt_speech_token,
|
||||||
|
prompt_feat=prompt_speech_feat,
|
||||||
|
embedding=flow_embedding,
|
||||||
|
uuid=this_uuid,
|
||||||
|
finalize=True)
|
||||||
|
yield {'tts_speech': this_tts_speech.cpu()}
|
||||||
|
with self.lock:
|
||||||
|
self.tts_speech_token_dict.pop(this_uuid)
|
||||||
|
self.llm_end_dict.pop(this_uuid)
|
||||||
|
self.mel_overlap_dict.pop(this_uuid)
|
||||||
|
self.hift_cache_dict.pop(this_uuid)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|||||||
embedding = self.spk_embed_affine_layer(embedding)
|
embedding = self.spk_embed_affine_layer(embedding)
|
||||||
|
|
||||||
# concat text and prompt_text
|
# concat text and prompt_text
|
||||||
|
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
||||||
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||||
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
|
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
|
||||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||||
@@ -118,17 +119,16 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|||||||
# text encode
|
# text encode
|
||||||
h, h_lengths = self.encoder(token, token_len)
|
h, h_lengths = self.encoder(token, token_len)
|
||||||
h = self.encoder_proj(h)
|
h = self.encoder_proj(h)
|
||||||
feat_len = (token_len / 50 * 22050 / 256).int()
|
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / 50 * 22050 / 256)
|
||||||
h, h_lengths = self.length_regulator(h, feat_len)
|
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
|
||||||
|
|
||||||
# get conditions
|
# get conditions
|
||||||
conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
|
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
|
||||||
if prompt_feat.shape[1] != 0:
|
conds[:, :mel_len1] = prompt_feat
|
||||||
for i, j in enumerate(prompt_feat_len):
|
|
||||||
conds[i, :j] = prompt_feat[i]
|
|
||||||
conds = conds.transpose(1, 2)
|
conds = conds.transpose(1, 2)
|
||||||
|
|
||||||
mask = (~make_pad_mask(feat_len)).to(h)
|
# mask = (~make_pad_mask(feat_len)).to(h)
|
||||||
|
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||||
feat = self.decoder(
|
feat = self.decoder(
|
||||||
mu=h.transpose(1, 2).contiguous(),
|
mu=h.transpose(1, 2).contiguous(),
|
||||||
mask=mask.unsqueeze(1),
|
mask=mask.unsqueeze(1),
|
||||||
@@ -136,6 +136,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|||||||
cond=conds,
|
cond=conds,
|
||||||
n_timesteps=10
|
n_timesteps=10
|
||||||
)
|
)
|
||||||
if prompt_feat.shape[1] != 0:
|
feat = feat[:, :, mel_len1:]
|
||||||
feat = feat[:, :, prompt_feat.shape[1]:]
|
assert feat.shape[2] == mel_len2
|
||||||
return feat
|
return feat
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from cosyvoice.utils.mask import make_pad_mask
|
from cosyvoice.utils.mask import make_pad_mask
|
||||||
|
|
||||||
@@ -43,7 +44,25 @@ class InterpolateRegulator(nn.Module):
|
|||||||
def forward(self, x, ylens=None):
|
def forward(self, x, ylens=None):
|
||||||
# x in (B, T, D)
|
# x in (B, T, D)
|
||||||
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
|
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
|
||||||
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
|
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
|
||||||
out = self.model(x).transpose(1, 2).contiguous()
|
out = self.model(x).transpose(1, 2).contiguous()
|
||||||
olens = ylens
|
olens = ylens
|
||||||
return out * mask, olens
|
return out * mask, olens
|
||||||
|
|
||||||
|
def inference(self, x1, x2, mel_len1, mel_len2):
|
||||||
|
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
|
||||||
|
# x in (B, T, D)
|
||||||
|
if x2.shape[1] > 40:
|
||||||
|
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
|
||||||
|
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
|
||||||
|
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
|
||||||
|
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
|
||||||
|
else:
|
||||||
|
x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
|
||||||
|
if x1.shape[1] != 0:
|
||||||
|
x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
|
||||||
|
x = torch.concat([x1, x2], dim=2)
|
||||||
|
else:
|
||||||
|
x = x2
|
||||||
|
out = self.model(x).transpose(1, 2).contiguous()
|
||||||
|
return out, mel_len1 + mel_len2
|
||||||
|
|||||||
@@ -335,10 +335,14 @@ class HiFTGenerator(nn.Module):
|
|||||||
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
|
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
|
||||||
return inverse_transform
|
return inverse_transform
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
||||||
f0 = self.f0_predictor(x)
|
f0 = self.f0_predictor(x)
|
||||||
s = self._f02source(f0)
|
s = self._f02source(f0)
|
||||||
|
|
||||||
|
# use cache_source to avoid glitch
|
||||||
|
if cache_source.shape[2] == 0:
|
||||||
|
s[:, :, :cache_source.shape[2]] = cache_source
|
||||||
|
|
||||||
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
||||||
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
||||||
|
|
||||||
@@ -370,7 +374,7 @@ class HiFTGenerator(nn.Module):
|
|||||||
|
|
||||||
x = self._istft(magnitude, phase)
|
x = self._istft(magnitude, phase)
|
||||||
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
||||||
return x
|
return x, s
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
print('Removing weight norm...')
|
print('Removing weight norm...')
|
||||||
@@ -387,5 +391,5 @@ class HiFTGenerator(nn.Module):
|
|||||||
l.remove_weight_norm()
|
l.remove_weight_norm()
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def inference(self, mel: torch.Tensor) -> torch.Tensor:
|
def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
||||||
return self.forward(x=mel)
|
return self.forward(x=mel, cache_source=cache_source)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Callable, List, Generator
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -31,6 +31,7 @@ class TransformerLM(torch.nn.Module):
|
|||||||
speech_token_size: int,
|
speech_token_size: int,
|
||||||
text_encoder: torch.nn.Module,
|
text_encoder: torch.nn.Module,
|
||||||
llm: torch.nn.Module,
|
llm: torch.nn.Module,
|
||||||
|
sampling: Callable,
|
||||||
length_normalized_loss: bool = True,
|
length_normalized_loss: bool = True,
|
||||||
lsm_weight: float = 0.0,
|
lsm_weight: float = 0.0,
|
||||||
spk_embed_dim: int = 192,
|
spk_embed_dim: int = 192,
|
||||||
@@ -63,6 +64,9 @@ class TransformerLM(torch.nn.Module):
|
|||||||
self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
|
self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
|
||||||
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
|
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
|
||||||
|
|
||||||
|
# 4. sampling method
|
||||||
|
self.sampling = sampling
|
||||||
|
|
||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
text: torch.Tensor,
|
text: torch.Tensor,
|
||||||
@@ -132,14 +136,12 @@ class TransformerLM(torch.nn.Module):
|
|||||||
def sampling_ids(
|
def sampling_ids(
|
||||||
self,
|
self,
|
||||||
weighted_scores: torch.Tensor,
|
weighted_scores: torch.Tensor,
|
||||||
sampling: Union[bool, int, float] = True,
|
decoded_tokens: List,
|
||||||
beam_size: int = 1,
|
sampling: int,
|
||||||
ignore_eos: bool = True,
|
ignore_eos: bool = True,
|
||||||
):
|
):
|
||||||
while True:
|
while True:
|
||||||
prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
|
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
||||||
top_ids = prob.multinomial(beam_size, replacement=True)
|
|
||||||
top_ids = indices[top_ids]
|
|
||||||
if (not ignore_eos) or (self.speech_token_size not in top_ids):
|
if (not ignore_eos) or (self.speech_token_size not in top_ids):
|
||||||
break
|
break
|
||||||
return top_ids
|
return top_ids
|
||||||
@@ -154,11 +156,10 @@ class TransformerLM(torch.nn.Module):
|
|||||||
prompt_speech_token: torch.Tensor,
|
prompt_speech_token: torch.Tensor,
|
||||||
prompt_speech_token_len: torch.Tensor,
|
prompt_speech_token_len: torch.Tensor,
|
||||||
embedding: torch.Tensor,
|
embedding: torch.Tensor,
|
||||||
beam_size: int = 1,
|
|
||||||
sampling: int = 25,
|
sampling: int = 25,
|
||||||
max_token_text_ratio: float = 20,
|
max_token_text_ratio: float = 20,
|
||||||
min_token_text_ratio: float = 2,
|
min_token_text_ratio: float = 2,
|
||||||
) -> torch.Tensor:
|
) -> Generator[torch.Tensor, None, None]:
|
||||||
device = text.device
|
device = text.device
|
||||||
text = torch.concat([prompt_text, text], dim=1)
|
text = torch.concat([prompt_text, text], dim=1)
|
||||||
text_len += prompt_text_len
|
text_len += prompt_text_len
|
||||||
@@ -173,7 +174,7 @@ class TransformerLM(torch.nn.Module):
|
|||||||
embedding = self.spk_embed_affine_layer(embedding)
|
embedding = self.spk_embed_affine_layer(embedding)
|
||||||
embedding = embedding.unsqueeze(dim=1)
|
embedding = embedding.unsqueeze(dim=1)
|
||||||
else:
|
else:
|
||||||
embedding = torch.zeros(1, 0, self.llm_input_size).to(device)
|
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||||
|
|
||||||
# 3. concat llm_input
|
# 3. concat llm_input
|
||||||
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
||||||
@@ -181,7 +182,7 @@ class TransformerLM(torch.nn.Module):
|
|||||||
if prompt_speech_token_len != 0:
|
if prompt_speech_token_len != 0:
|
||||||
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||||
else:
|
else:
|
||||||
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device)
|
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||||
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||||
|
|
||||||
# 4. cal min/max_length
|
# 4. cal min/max_length
|
||||||
@@ -196,11 +197,11 @@ class TransformerLM(torch.nn.Module):
|
|||||||
y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
|
y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
|
||||||
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
|
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
|
||||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
|
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
||||||
if top_ids == self.speech_token_size:
|
if top_ids == self.speech_token_size:
|
||||||
break
|
break
|
||||||
|
# in stream mode, yield token one by one
|
||||||
|
yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
|
||||||
out_tokens.append(top_ids)
|
out_tokens.append(top_ids)
|
||||||
offset += lm_input.size(1)
|
offset += lm_input.size(1)
|
||||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||||
|
|
||||||
return torch.tensor([out_tokens], dtype=torch.int64, device=device)
|
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
|||||||
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
||||||
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
||||||
|
|
||||||
def rel_shift(self, x):
|
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""Compute relative positional encoding.
|
"""Compute relative positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -233,10 +233,14 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
|||||||
torch.Tensor: Output tensor.
|
torch.Tensor: Output tensor.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
|
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
|
||||||
|
device=x.device,
|
||||||
|
dtype=x.dtype)
|
||||||
x_padded = torch.cat([zero_pad, x], dim=-1)
|
x_padded = torch.cat([zero_pad, x], dim=-1)
|
||||||
|
|
||||||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
x_padded = x_padded.view(x.size()[0],
|
||||||
|
x.size()[1],
|
||||||
|
x.size(3) + 1, x.size(2))
|
||||||
x = x_padded[:, :, 1:].view_as(x)[
|
x = x_padded[:, :, 1:].view_as(x)[
|
||||||
:, :, :, : x.size(-1) // 2 + 1
|
:, :, :, : x.size(-1) // 2 + 1
|
||||||
] # only keep the positions from 0 to time2
|
] # only keep the positions from 0 to time2
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ class TransformerDecoder(torch.nn.Module):
|
|||||||
memory_mask)
|
memory_mask)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@torch.jit.ignore(drop=True)
|
@torch.jit.unused
|
||||||
def forward_layers_checkpointed(self, x: torch.Tensor,
|
def forward_layers_checkpointed(self, x: torch.Tensor,
|
||||||
tgt_mask: torch.Tensor,
|
tgt_mask: torch.Tensor,
|
||||||
memory: torch.Tensor,
|
memory: torch.Tensor,
|
||||||
|
|||||||
@@ -212,7 +212,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, d_model, dropout_rate, max_len=5000):
|
def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
|
||||||
"""Construct an PositionalEncoding object."""
|
"""Construct an PositionalEncoding object."""
|
||||||
super(EspnetRelPositionalEncoding, self).__init__()
|
super(EspnetRelPositionalEncoding, self).__init__()
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
@@ -221,7 +221,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
|
|||||||
self.pe = None
|
self.pe = None
|
||||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||||
|
|
||||||
def extend_pe(self, x):
|
def extend_pe(self, x: torch.Tensor):
|
||||||
"""Reset the positional encodings."""
|
"""Reset the positional encodings."""
|
||||||
if self.pe is not None:
|
if self.pe is not None:
|
||||||
# self.pe contains both positive and negative parts
|
# self.pe contains both positive and negative parts
|
||||||
@@ -253,7 +253,8 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
|
|||||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0):
|
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
|
||||||
|
-> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Add positional encoding.
|
"""Add positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -169,7 +169,7 @@ class BaseEncoder(torch.nn.Module):
|
|||||||
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
||||||
return xs
|
return xs
|
||||||
|
|
||||||
@torch.jit.ignore(drop=True)
|
@torch.jit.unused
|
||||||
def forward_layers_checkpointed(self, xs: torch.Tensor,
|
def forward_layers_checkpointed(self, xs: torch.Tensor,
|
||||||
chunk_masks: torch.Tensor,
|
chunk_masks: torch.Tensor,
|
||||||
pos_emb: torch.Tensor,
|
pos_emb: torch.Tensor,
|
||||||
@@ -180,6 +180,7 @@ class BaseEncoder(torch.nn.Module):
|
|||||||
mask_pad)
|
mask_pad)
|
||||||
return xs
|
return xs
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
def forward_chunk(
|
def forward_chunk(
|
||||||
self,
|
self,
|
||||||
xs: torch.Tensor,
|
xs: torch.Tensor,
|
||||||
@@ -270,6 +271,7 @@ class BaseEncoder(torch.nn.Module):
|
|||||||
|
|
||||||
return (xs, r_att_cache, r_cnn_cache)
|
return (xs, r_att_cache, r_cnn_cache)
|
||||||
|
|
||||||
|
@torch.jit.unused
|
||||||
def forward_chunk_by_chunk(
|
def forward_chunk_by_chunk(
|
||||||
self,
|
self,
|
||||||
xs: torch.Tensor,
|
xs: torch.Tensor,
|
||||||
|
|||||||
@@ -101,3 +101,39 @@ def init_weights(m, mean=0.0, std=0.01):
|
|||||||
classname = m.__class__.__name__
|
classname = m.__class__.__name__
|
||||||
if classname.find("Conv") != -1:
|
if classname.find("Conv") != -1:
|
||||||
m.weight.data.normal_(mean, std)
|
m.weight.data.normal_(mean, std)
|
||||||
|
|
||||||
|
# Repetition Aware Sampling in VALL-E 2
|
||||||
|
def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
|
||||||
|
top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
|
||||||
|
rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
|
||||||
|
if rep_num >= win_size * tau_r:
|
||||||
|
top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
|
||||||
|
return top_ids
|
||||||
|
|
||||||
|
def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
|
||||||
|
prob, indices = [], []
|
||||||
|
cum_prob = 0.0
|
||||||
|
sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
|
||||||
|
for i in range(len(sorted_idx)):
|
||||||
|
# sampling both top-p and numbers.
|
||||||
|
if cum_prob < top_p and len(prob) < top_k:
|
||||||
|
cum_prob += sorted_value[i]
|
||||||
|
prob.append(sorted_value[i])
|
||||||
|
indices.append(sorted_idx[i])
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
prob = torch.tensor(prob).to(weighted_scores)
|
||||||
|
indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
|
||||||
|
top_ids = indices[prob.multinomial(1, replacement=True)]
|
||||||
|
return top_ids
|
||||||
|
|
||||||
|
def random_sampling(weighted_scores, decoded_tokens, sampling):
|
||||||
|
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
|
||||||
|
return top_ids
|
||||||
|
|
||||||
|
def fade_in_out(fade_in_mel, fade_out_mel, window):
|
||||||
|
device = fade_in_mel.device
|
||||||
|
fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
|
||||||
|
mel_overlap_len = int(window.shape[0] / 2)
|
||||||
|
fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
|
||||||
|
return fade_in_mel.to(device)
|
||||||
|
|||||||
@@ -15,6 +15,10 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
import logging
|
||||||
|
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||||
|
logging.basicConfig(level=logging.DEBUG,
|
||||||
|
format='%(asctime)s %(levelname)s %(message)s')
|
||||||
|
|
||||||
|
|
||||||
def read_lists(list_file):
|
def read_lists(list_file):
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
|
|||||||
num_blocks: 3
|
num_blocks: 3
|
||||||
dropout_rate: 0.1
|
dropout_rate: 0.1
|
||||||
positional_dropout_rate: 0.1
|
positional_dropout_rate: 0.1
|
||||||
attention_dropout_rate: 0
|
attention_dropout_rate: 0.0
|
||||||
normalize_before: True
|
normalize_before: True
|
||||||
input_layer: 'linear'
|
input_layer: 'linear'
|
||||||
pos_enc_layer_type: 'rel_pos_espnet'
|
pos_enc_layer_type: 'rel_pos_espnet'
|
||||||
@@ -49,11 +49,16 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
|
|||||||
num_blocks: 7
|
num_blocks: 7
|
||||||
dropout_rate: 0.1
|
dropout_rate: 0.1
|
||||||
positional_dropout_rate: 0.1
|
positional_dropout_rate: 0.1
|
||||||
attention_dropout_rate: 0
|
attention_dropout_rate: 0.0
|
||||||
input_layer: 'linear_legacy'
|
input_layer: 'linear_legacy'
|
||||||
pos_enc_layer_type: 'rel_pos_espnet'
|
pos_enc_layer_type: 'rel_pos_espnet'
|
||||||
selfattention_layer_type: 'rel_selfattn'
|
selfattention_layer_type: 'rel_selfattn'
|
||||||
static_chunk_size: 1
|
static_chunk_size: 1
|
||||||
|
sampling: !name:cosyvoice.utils.common.ras_sampling
|
||||||
|
top_p: 0.8
|
||||||
|
top_k: 25
|
||||||
|
win_size: 10
|
||||||
|
tau_r: 0.1
|
||||||
|
|
||||||
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
||||||
input_size: 512
|
input_size: 512
|
||||||
@@ -97,7 +102,7 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
|||||||
in_channels: 320
|
in_channels: 320
|
||||||
out_channels: 80
|
out_channels: 80
|
||||||
channels: [256, 256]
|
channels: [256, 256]
|
||||||
dropout: 0
|
dropout: 0.0
|
||||||
attention_head_dim: 64
|
attention_head_dim: 64
|
||||||
n_blocks: 4
|
n_blocks: 4
|
||||||
num_mid_blocks: 8
|
num_mid_blocks: 8
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
|
|||||||
num_blocks: 6
|
num_blocks: 6
|
||||||
dropout_rate: 0.1
|
dropout_rate: 0.1
|
||||||
positional_dropout_rate: 0.1
|
positional_dropout_rate: 0.1
|
||||||
attention_dropout_rate: 0
|
attention_dropout_rate: 0.0
|
||||||
normalize_before: True
|
normalize_before: True
|
||||||
input_layer: 'linear'
|
input_layer: 'linear'
|
||||||
pos_enc_layer_type: 'rel_pos_espnet'
|
pos_enc_layer_type: 'rel_pos_espnet'
|
||||||
@@ -49,11 +49,16 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
|
|||||||
num_blocks: 14
|
num_blocks: 14
|
||||||
dropout_rate: 0.1
|
dropout_rate: 0.1
|
||||||
positional_dropout_rate: 0.1
|
positional_dropout_rate: 0.1
|
||||||
attention_dropout_rate: 0
|
attention_dropout_rate: 0.0
|
||||||
input_layer: 'linear_legacy'
|
input_layer: 'linear_legacy'
|
||||||
pos_enc_layer_type: 'rel_pos_espnet'
|
pos_enc_layer_type: 'rel_pos_espnet'
|
||||||
selfattention_layer_type: 'rel_selfattn'
|
selfattention_layer_type: 'rel_selfattn'
|
||||||
static_chunk_size: 1
|
static_chunk_size: 1
|
||||||
|
sampling: !name:cosyvoice.utils.common.ras_sampling
|
||||||
|
top_p: 0.8
|
||||||
|
top_k: 25
|
||||||
|
win_size: 10
|
||||||
|
tau_r: 0.1
|
||||||
|
|
||||||
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
||||||
input_size: 512
|
input_size: 512
|
||||||
@@ -97,7 +102,7 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
|||||||
in_channels: 320
|
in_channels: 320
|
||||||
out_channels: 80
|
out_channels: 80
|
||||||
channels: [256, 256]
|
channels: [256, 256]
|
||||||
dropout: 0
|
dropout: 0.0
|
||||||
attention_head_dim: 64
|
attention_head_dim: 64
|
||||||
n_blocks: 4
|
n_blocks: 4
|
||||||
num_mid_blocks: 12
|
num_mid_blocks: 12
|
||||||
|
|||||||
@@ -61,8 +61,11 @@ def main():
|
|||||||
request.instruct_request.CopyFrom(instruct_request)
|
request.instruct_request.CopyFrom(instruct_request)
|
||||||
|
|
||||||
response = stub.Inference(request)
|
response = stub.Inference(request)
|
||||||
|
tts_audio = b''
|
||||||
|
for r in response:
|
||||||
|
tts_audio += r.tts_audio
|
||||||
|
tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
|
||||||
logging.info('save response to {}'.format(args.tts_wav))
|
logging.info('save response to {}'.format(args.tts_wav))
|
||||||
tts_speech = torch.from_numpy(np.array(np.frombuffer(response.tts_audio, dtype=np.int16))).unsqueeze(dim=0)
|
|
||||||
torchaudio.save(args.tts_wav, tts_speech, target_sr)
|
torchaudio.save(args.tts_wav, tts_speech, target_sr)
|
||||||
logging.info('get response')
|
logging.info('get response')
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ package cosyvoice;
|
|||||||
option go_package = "protos/";
|
option go_package = "protos/";
|
||||||
|
|
||||||
service CosyVoice{
|
service CosyVoice{
|
||||||
rpc Inference(Request) returns (Response) {}
|
rpc Inference(Request) returns (stream Response) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
message Request{
|
message Request{
|
||||||
|
|||||||
@@ -54,9 +54,10 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
|
|||||||
model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
|
model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
|
||||||
|
|
||||||
logging.info('send inference response')
|
logging.info('send inference response')
|
||||||
response = cosyvoice_pb2.Response()
|
for i in model_output:
|
||||||
response.tts_audio = (model_output['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
|
response = cosyvoice_pb2.Response()
|
||||||
return response
|
response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
|
||||||
|
yield response
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
|
grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
|
||||||
|
|||||||
41
webui.py
41
webui.py
@@ -24,14 +24,8 @@ import torchaudio
|
|||||||
import random
|
import random
|
||||||
import librosa
|
import librosa
|
||||||
|
|
||||||
import logging
|
|
||||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||||
from cosyvoice.utils.file_utils import load_wav, speed_change
|
from cosyvoice.utils.file_utils import load_wav, speed_change, logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG,
|
|
||||||
format='%(asctime)s %(levelname)s %(message)s')
|
|
||||||
|
|
||||||
def generate_seed():
|
def generate_seed():
|
||||||
seed = random.randint(1, 100000000)
|
seed = random.randint(1, 100000000)
|
||||||
@@ -63,10 +57,11 @@ instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成
|
|||||||
'3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
|
'3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
|
||||||
'跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
|
'跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
|
||||||
'自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
|
'自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
|
||||||
|
stream_mode_list = [('否', False), ('是', True)]
|
||||||
def change_instruction(mode_checkbox_group):
|
def change_instruction(mode_checkbox_group):
|
||||||
return instruct_dict[mode_checkbox_group]
|
return instruct_dict[mode_checkbox_group]
|
||||||
|
|
||||||
def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, speed_factor):
|
def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor):
|
||||||
if prompt_wav_upload is not None:
|
if prompt_wav_upload is not None:
|
||||||
prompt_wav = prompt_wav_upload
|
prompt_wav = prompt_wav_upload
|
||||||
elif prompt_wav_record is not None:
|
elif prompt_wav_record is not None:
|
||||||
@@ -117,32 +112,25 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
|
|||||||
if mode_checkbox_group == '预训练音色':
|
if mode_checkbox_group == '预训练音色':
|
||||||
logging.info('get sft inference request')
|
logging.info('get sft inference request')
|
||||||
set_all_random_seed(seed)
|
set_all_random_seed(seed)
|
||||||
output = cosyvoice.inference_sft(tts_text, sft_dropdown)
|
for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream):
|
||||||
|
yield (target_sr, i['tts_speech'].numpy().flatten())
|
||||||
elif mode_checkbox_group == '3s极速复刻':
|
elif mode_checkbox_group == '3s极速复刻':
|
||||||
logging.info('get zero_shot inference request')
|
logging.info('get zero_shot inference request')
|
||||||
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
|
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
|
||||||
set_all_random_seed(seed)
|
set_all_random_seed(seed)
|
||||||
output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
|
for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream):
|
||||||
|
yield (target_sr, i['tts_speech'].numpy().flatten())
|
||||||
elif mode_checkbox_group == '跨语种复刻':
|
elif mode_checkbox_group == '跨语种复刻':
|
||||||
logging.info('get cross_lingual inference request')
|
logging.info('get cross_lingual inference request')
|
||||||
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
|
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
|
||||||
set_all_random_seed(seed)
|
set_all_random_seed(seed)
|
||||||
output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
|
for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream):
|
||||||
|
yield (target_sr, i['tts_speech'].numpy().flatten())
|
||||||
else:
|
else:
|
||||||
logging.info('get instruct inference request')
|
logging.info('get instruct inference request')
|
||||||
set_all_random_seed(seed)
|
set_all_random_seed(seed)
|
||||||
output = cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text)
|
for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream):
|
||||||
|
yield (target_sr, i['tts_speech'].numpy().flatten())
|
||||||
if speed_factor != 1.0:
|
|
||||||
try:
|
|
||||||
audio_data, sample_rate = speed_change(output["tts_speech"], target_sr, str(speed_factor))
|
|
||||||
audio_data = audio_data.numpy().flatten()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to change speed of audio: \n{e}")
|
|
||||||
else:
|
|
||||||
audio_data = output['tts_speech'].numpy().flatten()
|
|
||||||
|
|
||||||
return (target_sr, audio_data)
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
@@ -155,6 +143,7 @@ def main():
|
|||||||
mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
|
mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
|
||||||
instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
|
instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
|
||||||
sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25)
|
sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25)
|
||||||
|
stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1])
|
||||||
with gr.Column(scale=0.25):
|
with gr.Column(scale=0.25):
|
||||||
seed_button = gr.Button(value="\U0001F3B2")
|
seed_button = gr.Button(value="\U0001F3B2")
|
||||||
seed = gr.Number(value=0, label="随机推理种子")
|
seed = gr.Number(value=0, label="随机推理种子")
|
||||||
@@ -167,11 +156,11 @@ def main():
|
|||||||
|
|
||||||
generate_button = gr.Button("生成音频")
|
generate_button = gr.Button("生成音频")
|
||||||
|
|
||||||
audio_output = gr.Audio(label="合成音频")
|
audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=True)
|
||||||
|
|
||||||
seed_button.click(generate_seed, inputs=[], outputs=seed)
|
seed_button.click(generate_seed, inputs=[], outputs=seed)
|
||||||
generate_button.click(generate_audio,
|
generate_button.click(generate_audio,
|
||||||
inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, speed_factor],
|
inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor],
|
||||||
outputs=[audio_output])
|
outputs=[audio_output])
|
||||||
mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
|
mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
|
||||||
demo.queue(max_size=4, default_concurrency_limit=2)
|
demo.queue(max_size=4, default_concurrency_limit=2)
|
||||||
@@ -184,7 +173,7 @@ if __name__ == '__main__':
|
|||||||
default=8000)
|
default=8000)
|
||||||
parser.add_argument('--model_dir',
|
parser.add_argument('--model_dir',
|
||||||
type=str,
|
type=str,
|
||||||
default='iic/CosyVoice-300M',
|
default='pretrained_models/CosyVoice-300M',
|
||||||
help='local path or modelscope repo id')
|
help='local path or modelscope repo id')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
cosyvoice = CosyVoice(args.model_dir)
|
cosyvoice = CosyVoice(args.model_dir)
|
||||||
|
|||||||
Reference in New Issue
Block a user