Merge pull request #327 from FunAudioLLM/inference_streaming

Inference streaming
This commit is contained in:
Xiang Lyu
2024-08-29 23:51:21 +08:00
committed by GitHub
23 changed files with 412 additions and 132 deletions

.gitignore
View File

@@ -43,6 +43,8 @@ compile_commands.json
 # train/inference files
 *.wav
+*.m4a
+*.aac
 *.pt
 pretrained_models/*
 *_pb2_grpc.py

View File

@@ -116,23 +116,24 @@ import torchaudio
 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
 # sft usage
 print(cosyvoice.list_avaliable_spks())
-output = cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女')
-torchaudio.save('sft.wav', output['tts_speech'], 22050)
+# change stream=True for chunk stream inference
+for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
+    torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], 22050)

 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
 # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
 prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
-output = cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k)
-torchaudio.save('zero_shot.wav', output['tts_speech'], 22050)
+for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
+    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], 22050)
 # cross_lingual usage
 prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
-output = cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k)
-torchaudio.save('cross_lingual.wav', output['tts_speech'], 22050)
+for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
+    torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], 22050)

 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
 # instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
-output = cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
-torchaudio.save('instruct.wav', output['tts_speech'], 22050)
+for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
+    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], 22050)
 ```
 **Start web demo**
@@ -163,10 +164,10 @@ docker build -t cosyvoice:v1.0 .
 # change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
 # for grpc usage
 docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
-python3 grpc/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
+cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
 # for fastapi usage
 docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && MODEL_DIR=iic/CosyVoice-300M fastapi dev --port 50000 server.py && sleep infinity"
-python3 fastapi/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
+cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
 ```

 ## Discussion & Communication
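
For readers trying the new API, here is a minimal sketch (not part of this diff) of consuming `stream=True` output and stitching the yielded chunks back into one file; the model path and text reuse the README's own example, and chunk sizes depend on your hardware:

```
# minimal sketch, assuming pretrained_models/CosyVoice-300M-SFT is already downloaded
import torch
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
chunks = []
for chunk in cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=True):
    # each yielded dict carries one 22050 Hz audio chunk of shape (1, T)
    chunks.append(chunk['tts_speech'])
torchaudio.save('sft_stream.wav', torch.concat(chunks, dim=1), 22050)
```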

View File

@@ -0,0 +1,64 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import sys
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
import torch
from cosyvoice.cli.cosyvoice import CosyVoice
def get_args():
    parser = argparse.ArgumentParser(description='export your model for deployment')
    parser.add_argument('--model_dir',
                        type=str,
                        default='pretrained_models/CosyVoice-300M',
                        help='local path')
    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    torch._C._jit_set_fusion_strategy([('STATIC', 1)])
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)

    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)

    # 1. export llm text_encoder
    llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
    script = torch.jit.script(llm_text_encoder)
    script = torch.jit.freeze(script)
    script = torch.jit.optimize_for_inference(script)
    script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))

    # 2. export llm llm
    llm_llm = cosyvoice.model.llm.llm.half()
    script = torch.jit.script(llm_llm)
    script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
    script = torch.jit.optimize_for_inference(script)
    script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))


if __name__ == '__main__':
    main()
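
A quick way to sanity-check the exported archives (a sketch, not part of this commit; it only assumes the two file names written by the script above):

```
# load the TorchScript archives produced by export_jit.py and confirm they deserialize
import torch

model_dir = 'pretrained_models/CosyVoice-300M'
for name in ['llm.text_encoder.fp16.zip', 'llm.llm.fp16.zip']:
    module = torch.jit.load('{}/{}'.format(model_dir, name))
    module.eval()
    print(name, '->', type(module).__name__)
```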

View File

@@ -0,0 +1,8 @@
# TODO: following the same logic as export_jit, export the flow estimator to ONNX.
# Document the TensorRT installation steps here. If TensorRT is not installed, do not run this script;
# prompt the user to install it first instead of offering a choice.
try:
    import tensorrt
except ImportError:
    print('step 1: download TensorRT\nstep 2: unpack it and install the wheel')
# In the export command, pass the TensorRT root directory via an environment variable,
# e.g. os.environ['tensorrt_root_dir']/bin/trtexec, then run the export command from Python with subprocess.
# The run command will later be added to run.sh, e.g. tensorrt_root_dir=xxxx python cosyvoice/bin/export_trt.py --model_dir xxx
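
The file above is only a TODO. As a rough illustration of the planned step, a hedged sketch is below; the `flow.decoder.estimator` attribute path, the dummy shapes, and the output file name are assumptions, not something this commit defines:

```
# hypothetical sketch of exporting the flow estimator to ONNX; shapes and names are assumptions
import torch
from cosyvoice.cli.cosyvoice import CosyVoice

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M', load_jit=False)
estimator = cosyvoice.model.flow.decoder.estimator  # assumed location of the CFM estimator
dummy = (torch.randn(1, 80, 200),   # x: noisy mel
         torch.ones(1, 1, 200),     # mask
         torch.randn(1, 80, 200),   # mu: encoder output
         torch.rand(1),             # t: flow-matching timestep
         torch.randn(1, 80),        # spks: speaker embedding
         torch.randn(1, 80, 200))   # cond: prompt condition
torch.onnx.export(estimator, dummy, 'flow.decoder.estimator.fp32.onnx',
                  input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
                  output_names=['estimator_out'], opset_version=17)
```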

View File

@@ -100,10 +100,13 @@ def main():
                        'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                        'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                        'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
-        model_output = model.inference(**model_input)
+        tts_speeches = []
+        for model_output in model.inference(**model_input):
+            tts_speeches.append(model_output['tts_speech'])
+        tts_speeches = torch.concat(tts_speeches, dim=1)
         tts_key = '{}_{}'.format(utts[0], tts_index[0])
         tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
-        torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
+        torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
         f.write('{} {}\n'.format(tts_key, tts_fn))
         f.flush()
     f.close()

View File

@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import torch
+import time
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.utils.file_utils import logging


 class CosyVoice:

-    def __init__(self, model_dir):
+    def __init__(self, model_dir, load_jit=True):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -38,46 +39,61 @@ class CosyVoice:
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
+        if load_jit:
+            self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
+                                '{}/llm.llm.fp16.zip'.format(model_dir))
         del configs

     def list_avaliable_spks(self):
         spks = list(self.frontend.spk2info.keys())
         return spks

-    def inference_sft(self, tts_text, spk_id):
-        tts_speeches = []
+    def inference_sft(self, tts_text, spk_id, stream=False):
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_sft(i, spk_id)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()

-    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
         prompt_text = self.frontend.text_normalize(prompt_text, split=False)
-        tts_speeches = []
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()

-    def inference_cross_lingual(self, tts_text, prompt_speech_16k):
+    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
         if self.frontend.instruct is True:
             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
-        tts_speeches = []
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()

-    def inference_instruct(self, tts_text, spk_id, instruct_text):
+    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
         if self.frontend.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
         instruct_text = self.frontend.text_normalize(instruct_text, split=False)
-        tts_speeches = []
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()

View File

@@ -12,6 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+import numpy as np
+import threading
+import time
+from contextlib import nullcontext
+import uuid
+from cosyvoice.utils.common import fade_in_out


 class CosyVoiceModel:
@@ -23,38 +30,143 @@ class CosyVoiceModel:
         self.llm = llm
         self.flow = flow
         self.hift = hift
+        self.token_min_hop_len = 100
+        self.token_max_hop_len = 200
+        self.token_overlap_len = 20
+        # mel fade in out
+        self.mel_overlap_len = 34
+        self.mel_window = np.hamming(2 * self.mel_overlap_len)
+        # hift cache
+        self.mel_cache_len = 20
+        self.source_cache_len = int(self.mel_cache_len * 256)
+        # rtf and decoding related
+        self.stream_scale_factor = 1
+        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
+        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.lock = threading.Lock()
+        # dict used to store session related variable
+        self.tts_speech_token_dict = {}
+        self.llm_end_dict = {}
+        self.mel_overlap_dict = {}
+        self.hift_cache_dict = {}

     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
         self.llm.to(self.device).eval()
+        self.llm.half()
         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
         self.flow.to(self.device).eval()
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
         self.hift.to(self.device).eval()

-    def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
-                  prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
-                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
-                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
-                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
-        tts_speech_token = self.llm.inference(text=text.to(self.device),
-                                              text_len=text_len.to(self.device),
-                                              prompt_text=prompt_text.to(self.device),
-                                              prompt_text_len=prompt_text_len.to(self.device),
-                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                              prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
-                                              embedding=llm_embedding.to(self.device),
-                                              beam_size=1,
-                                              sampling=25,
-                                              max_token_text_ratio=30,
-                                              min_token_text_ratio=3)
-        tts_mel = self.flow.inference(token=tts_speech_token,
-                                      token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
-                                      prompt_token=flow_prompt_speech_token.to(self.device),
-                                      prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                      prompt_feat=prompt_speech_feat.to(self.device),
-                                      prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                      embedding=flow_embedding.to(self.device))
-        tts_speech = self.hift.inference(mel=tts_mel).cpu()
-        torch.cuda.empty_cache()
-        return {'tts_speech': tts_speech}
+    def load_jit(self, llm_text_encoder_model, llm_llm_model):
+        llm_text_encoder = torch.jit.load(llm_text_encoder_model)
+        self.llm.text_encoder = llm_text_encoder
+        llm_llm = torch.jit.load(llm_llm_model)
+        self.llm.llm = llm_llm
+
+    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        with self.llm_context:
+            for i in self.llm.inference(text=text.to(self.device),
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                        embedding=llm_embedding.to(self.device).half(),
+                                        sampling=25,
+                                        max_token_text_ratio=30,
+                                        min_token_text_ratio=3):
+                self.tts_speech_token_dict[uuid].append(i)
+        self.llm_end_dict[uuid] = True
+
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
+        with self.flow_hift_context:
+            tts_mel = self.flow.inference(token=token.to(self.device),
+                                          token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                          prompt_token=prompt_token.to(self.device),
+                                          prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                          prompt_feat=prompt_feat.to(self.device),
+                                          prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                          embedding=embedding.to(self.device))
+            # mel overlap fade in out
+            if self.mel_overlap_dict[uuid] is not None:
+                tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
+            # append hift cache
+            if self.hift_cache_dict[uuid] is not None:
+                hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+            else:
+                hift_cache_source = torch.zeros(1, 1, 0)
+            # keep overlap mel and hift cache
+            if finalize is False:
+                self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
+                tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
+                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+                self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
+                tts_speech = tts_speech[:, :-self.source_cache_len]
+            else:
+                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+        return tts_speech
+
+    def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+                  prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+                  prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
+        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+        p.start()
+        if stream is True:
+            token_hop_len = self.token_min_hop_len
+            while True:
+                time.sleep(0.1)
+                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                    this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
+                    with self.flow_hift_context:
+                        this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                         prompt_token=flow_prompt_speech_token,
+                                                         prompt_feat=prompt_speech_feat,
+                                                         embedding=flow_embedding,
+                                                         uuid=this_uuid,
+                                                         finalize=False)
+                    yield {'tts_speech': this_tts_speech.cpu()}
+                    with self.lock:
+                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
+                    # increase token_hop_len for better speech quality
+                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
+                    break
+            p.join()
+            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+            with self.flow_hift_context:
+                this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                 prompt_token=flow_prompt_speech_token,
+                                                 prompt_feat=prompt_speech_feat,
+                                                 embedding=flow_embedding,
+                                                 uuid=this_uuid,
+                                                 finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        else:
+            # deal with all tokens
+            p.join()
+            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+            with self.flow_hift_context:
+                this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                 prompt_token=flow_prompt_speech_token,
+                                                 prompt_feat=prompt_speech_feat,
+                                                 embedding=flow_embedding,
+                                                 uuid=this_uuid,
+                                                 finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        with self.lock:
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
+            self.mel_overlap_dict.pop(this_uuid)
+            self.hift_cache_dict.pop(this_uuid)
+        torch.cuda.synchronize()

View File

@@ -111,6 +111,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         embedding = self.spk_embed_affine_layer(embedding)

         # concat text and prompt_text
+        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
         mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
@@ -118,17 +119,16 @@ class MaskedDiffWithXvec(torch.nn.Module):
         # text encode
         h, h_lengths = self.encoder(token, token_len)
         h = self.encoder_proj(h)
-        feat_len = (token_len / 50 * 22050 / 256).int()
-        h, h_lengths = self.length_regulator(h, feat_len)
+        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / 50 * 22050 / 256)
+        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)

         # get conditions
-        conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
-        if prompt_feat.shape[1] != 0:
-            for i, j in enumerate(prompt_feat_len):
-                conds[i, :j] = prompt_feat[i]
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)

-        mask = (~make_pad_mask(feat_len)).to(h)
+        # mask = (~make_pad_mask(feat_len)).to(h)
+        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
         feat = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
@@ -136,6 +136,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
             cond=conds,
             n_timesteps=10
         )
-        if prompt_feat.shape[1] != 0:
-            feat = feat[:, :, prompt_feat.shape[1]:]
+        feat = feat[:, :, mel_len1:]
+        assert feat.shape[2] == mel_len2
         return feat

View File

@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import Tuple
 import torch.nn as nn
+import torch
 from torch.nn import functional as F
 from cosyvoice.utils.mask import make_pad_mask
@@ -43,7 +44,25 @@ class InterpolateRegulator(nn.Module):
     def forward(self, x, ylens=None):
         # x in (B, T, D)
         mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
-        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
         out = self.model(x).transpose(1, 2).contiguous()
         olens = ylens
         return out * mask, olens
+
+    def inference(self, x1, x2, mel_len1, mel_len2):
+        # in inference mode, interpolate prompt token and token (head/mid/tail) separately, so we can get a clear separation point of mel
+        # x in (B, T, D)
+        if x2.shape[1] > 40:
+            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
+            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
+            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
+            x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
+        else:
+            x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
+        if x1.shape[1] != 0:
+            x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
+            x = torch.concat([x1, x2], dim=2)
+        else:
+            x = x2
+        out = self.model(x).transpose(1, 2).contiguous()
+        return out, mel_len1 + mel_len2

View File

@@ -335,10 +335,14 @@ class HiFTGenerator(nn.Module):
         inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
         return inverse_transform

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
         f0 = self.f0_predictor(x)
         s = self._f02source(f0)

+        # use cache_source to avoid glitch
+        if cache_source.shape[2] != 0:
+            s[:, :, :cache_source.shape[2]] = cache_source
+
         s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
         s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
@@ -370,7 +374,7 @@ class HiFTGenerator(nn.Module):
         x = self._istft(magnitude, phase)
         x = torch.clamp(x, -self.audio_limit, self.audio_limit)
-        return x
+        return x, s

     def remove_weight_norm(self):
         print('Removing weight norm...')
@@ -387,5 +391,5 @@ class HiFTGenerator(nn.Module):
             l.remove_weight_norm()

     @torch.inference_mode()
-    def inference(self, mel: torch.Tensor) -> torch.Tensor:
-        return self.forward(x=mel)
+    def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+        return self.forward(x=mel, cache_source=cache_source)

View File

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Callable, List, Generator
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -31,6 +31,7 @@ class TransformerLM(torch.nn.Module):
             speech_token_size: int,
             text_encoder: torch.nn.Module,
             llm: torch.nn.Module,
+            sampling: Callable,
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
             spk_embed_dim: int = 192,
@@ -63,6 +64,9 @@ class TransformerLM(torch.nn.Module):
         self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
         self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)

+        # 4. sampling method
+        self.sampling = sampling
+
     def encode(
             self,
             text: torch.Tensor,
@@ -132,14 +136,12 @@ class TransformerLM(torch.nn.Module):
     def sampling_ids(
             self,
             weighted_scores: torch.Tensor,
-            sampling: Union[bool, int, float] = True,
-            beam_size: int = 1,
+            decoded_tokens: List,
+            sampling: int,
             ignore_eos: bool = True,
     ):
         while True:
-            prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
-            top_ids = prob.multinomial(beam_size, replacement=True)
-            top_ids = indices[top_ids]
+            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             if (not ignore_eos) or (self.speech_token_size not in top_ids):
                 break
         return top_ids
@@ -154,11 +156,10 @@
             prompt_speech_token: torch.Tensor,
             prompt_speech_token_len: torch.Tensor,
             embedding: torch.Tensor,
-            beam_size: int = 1,
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
-    ) -> torch.Tensor:
+    ) -> Generator[torch.Tensor, None, None]:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
         text_len += prompt_text_len
@@ -173,7 +174,7 @@ class TransformerLM(torch.nn.Module):
             embedding = self.spk_embed_affine_layer(embedding)
             embedding = embedding.unsqueeze(dim=1)
         else:
-            embedding = torch.zeros(1, 0, self.llm_input_size).to(device)
+            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)

         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
@@ -181,7 +182,7 @@
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
-            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device)
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
         lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)

         # 4. cal min/max_length
@@ -196,11 +197,11 @@
             y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
                                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
             if top_ids == self.speech_token_size:
                 break
+            # in stream mode, yield token one by one
+            yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
             out_tokens.append(top_ids)
             offset += lm_input.size(1)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
-
-        return torch.tensor([out_tokens], dtype=torch.int64, device=device)

View File

@@ -222,7 +222,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         torch.nn.init.xavier_uniform_(self.pos_bias_u)
         torch.nn.init.xavier_uniform_(self.pos_bias_v)

-    def rel_shift(self, x):
+    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
         """Compute relative positional encoding.

         Args:
@@ -233,10 +233,14 @@
             torch.Tensor: Output tensor.

         """
-        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
+                               device=x.device,
+                               dtype=x.dtype)
         x_padded = torch.cat([zero_pad, x], dim=-1)

-        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+        x_padded = x_padded.view(x.size()[0],
+                                 x.size()[1],
+                                 x.size(3) + 1, x.size(2))
         x = x_padded[:, :, 1:].view_as(x)[
             :, :, :, : x.size(-1) // 2 + 1
         ]  # only keep the positions from 0 to time2

View File

@@ -174,7 +174,7 @@ class TransformerDecoder(torch.nn.Module):
                                        memory_mask)
         return x

-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward_layers_checkpointed(self, x: torch.Tensor,
                                     tgt_mask: torch.Tensor,
                                     memory: torch.Tensor,

View File

@@ -212,7 +212,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
""" """
def __init__(self, d_model, dropout_rate, max_len=5000): def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
"""Construct an PositionalEncoding object.""" """Construct an PositionalEncoding object."""
super(EspnetRelPositionalEncoding, self).__init__() super(EspnetRelPositionalEncoding, self).__init__()
self.d_model = d_model self.d_model = d_model
@@ -221,7 +221,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
self.pe = None self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len)) self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x): def extend_pe(self, x: torch.Tensor):
"""Reset the positional encodings.""" """Reset the positional encodings."""
if self.pe is not None: if self.pe is not None:
# self.pe contains both positive and negative parts # self.pe contains both positive and negative parts
@@ -253,7 +253,8 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
pe = torch.cat([pe_positive, pe_negative], dim=1) pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype) self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0): def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
-> Tuple[torch.Tensor, torch.Tensor]:
"""Add positional encoding. """Add positional encoding.
Args: Args:

View File

@@ -169,7 +169,7 @@ class BaseEncoder(torch.nn.Module):
             xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         return xs

-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward_layers_checkpointed(self, xs: torch.Tensor,
                                     chunk_masks: torch.Tensor,
                                     pos_emb: torch.Tensor,
@@ -180,6 +180,7 @@
                                               mask_pad)
         return xs

+    @torch.jit.export
     def forward_chunk(
         self,
         xs: torch.Tensor,
@@ -270,6 +271,7 @@
         return (xs, r_att_cache, r_cnn_cache)

+    @torch.jit.unused
     def forward_chunk_by_chunk(
         self,
         xs: torch.Tensor,

View File

@@ -101,3 +101,39 @@ def init_weights(m, mean=0.0, std=0.01):
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
         m.weight.data.normal_(mean, std)
+
+# Repetition Aware Sampling in VALL-E 2
+def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
+    top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
+    rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
+    if rep_num >= win_size * tau_r:
+        top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
+    return top_ids
+
+def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
+    prob, indices = [], []
+    cum_prob = 0.0
+    sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
+    for i in range(len(sorted_idx)):
+        # sampling both top-p and numbers.
+        if cum_prob < top_p and len(prob) < top_k:
+            cum_prob += sorted_value[i]
+            prob.append(sorted_value[i])
+            indices.append(sorted_idx[i])
+        else:
+            break
+    prob = torch.tensor(prob).to(weighted_scores)
+    indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
+    top_ids = indices[prob.multinomial(1, replacement=True)]
+    return top_ids
+
+def random_sampling(weighted_scores, decoded_tokens, sampling):
+    top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
+    return top_ids
+
+def fade_in_out(fade_in_mel, fade_out_mel, window):
+    device = fade_in_mel.device
+    fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
+    mel_overlap_len = int(window.shape[0] / 2)
+    fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
+    return fade_in_mel.to(device)
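
The helpers above are easiest to see with toy tensors. The sketch below (not part of this diff) cross-fades two dummy mel chunks the way CosyVoiceModel wires `fade_in_out` up, with the same 34-frame overlap and Hamming window; all tensor values are illustrative:

```
# illustrative use of fade_in_out with the 34-frame Hamming window used by CosyVoiceModel
import numpy as np
import torch
from cosyvoice.utils.common import fade_in_out

mel_overlap_len = 34
window = np.hamming(2 * mel_overlap_len)
prev_chunk = torch.randn(1, 80, 120)            # mel already sent downstream (B, n_mels, T)
next_chunk = torch.randn(1, 80, 150)            # freshly decoded mel
overlap = prev_chunk[:, :, -mel_overlap_len:]   # tail kept in mel_overlap_dict
smoothed = fade_in_out(next_chunk, overlap, window)
print(smoothed.shape)                           # torch.Size([1, 80, 150]); only the first 34 frames are blended
```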

View File

@@ -15,6 +15,10 @@
 import json
 import torchaudio
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+logging.basicConfig(level=logging.DEBUG,
+                    format='%(asctime)s %(levelname)s %(message)s')

 def read_lists(list_file):

View File

@@ -31,7 +31,7 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
 num_blocks: 3
 dropout_rate: 0.1
 positional_dropout_rate: 0.1
-attention_dropout_rate: 0
+attention_dropout_rate: 0.0
 normalize_before: True
 input_layer: 'linear'
 pos_enc_layer_type: 'rel_pos_espnet'
@@ -49,11 +49,16 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
 num_blocks: 7
 dropout_rate: 0.1
 positional_dropout_rate: 0.1
-attention_dropout_rate: 0
+attention_dropout_rate: 0.0
 input_layer: 'linear_legacy'
 pos_enc_layer_type: 'rel_pos_espnet'
 selfattention_layer_type: 'rel_selfattn'
 static_chunk_size: 1
+sampling: !name:cosyvoice.utils.common.ras_sampling
+    top_p: 0.8
+    top_k: 25
+    win_size: 10
+    tau_r: 0.1

 flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
 input_size: 512
@@ -97,7 +102,7 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
 in_channels: 320
 out_channels: 80
 channels: [256, 256]
-dropout: 0
+dropout: 0.0
 attention_head_dim: 64
 n_blocks: 4
 num_mid_blocks: 8

View File

@@ -31,7 +31,7 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
 num_blocks: 6
 dropout_rate: 0.1
 positional_dropout_rate: 0.1
-attention_dropout_rate: 0
+attention_dropout_rate: 0.0
 normalize_before: True
 input_layer: 'linear'
 pos_enc_layer_type: 'rel_pos_espnet'
@@ -49,11 +49,16 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
 num_blocks: 14
 dropout_rate: 0.1
 positional_dropout_rate: 0.1
-attention_dropout_rate: 0
+attention_dropout_rate: 0.0
 input_layer: 'linear_legacy'
 pos_enc_layer_type: 'rel_pos_espnet'
 selfattention_layer_type: 'rel_selfattn'
 static_chunk_size: 1
+sampling: !name:cosyvoice.utils.common.ras_sampling
+    top_p: 0.8
+    top_k: 25
+    win_size: 10
+    tau_r: 0.1

 flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
 input_size: 512
@@ -97,7 +102,7 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
 in_channels: 320
 out_channels: 80
 channels: [256, 256]
-dropout: 0
+dropout: 0.0
 attention_head_dim: 64
 n_blocks: 4
 num_mid_blocks: 12

View File

@@ -61,8 +61,11 @@ def main():
     request.instruct_request.CopyFrom(instruct_request)

     response = stub.Inference(request)
+    tts_audio = b''
+    for r in response:
+        tts_audio += r.tts_audio
+    tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
     logging.info('save response to {}'.format(args.tts_wav))
-    tts_speech = torch.from_numpy(np.array(np.frombuffer(response.tts_audio, dtype=np.int16))).unsqueeze(dim=0)
     torchaudio.save(args.tts_wav, tts_speech, target_sr)
     logging.info('get response')

View File

@@ -4,7 +4,7 @@ package cosyvoice;
 option go_package = "protos/";

 service CosyVoice{
-  rpc Inference(Request) returns (Response) {}
+  rpc Inference(Request) returns (stream Response) {}
 }

 message Request{

View File

@@ -54,9 +54,10 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
             model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)

         logging.info('send inference response')
-        response = cosyvoice_pb2.Response()
-        response.tts_audio = (model_output['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
-        return response
+        for i in model_output:
+            response = cosyvoice_pb2.Response()
+            response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
+            yield response

 def main():
     grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)

View File

@@ -24,14 +24,8 @@ import torchaudio
 import random
 import librosa
-import logging
-logging.getLogger('matplotlib').setLevel(logging.WARNING)
 from cosyvoice.cli.cosyvoice import CosyVoice
-from cosyvoice.utils.file_utils import load_wav, speed_change
-logging.basicConfig(level=logging.DEBUG,
-                    format='%(asctime)s %(levelname)s %(message)s')
+from cosyvoice.utils.file_utils import load_wav, speed_change, logging

 def generate_seed():
     seed = random.randint(1, 100000000)
@@ -63,10 +57,11 @@ instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成
                  '3s极速复刻': '1. 选择prompt音频文件或录入prompt音频注意不超过30s若同时提供优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
                  '跨语种复刻': '1. 选择prompt音频文件或录入prompt音频注意不超过30s若同时提供优先选择prompt音频文件\n2. 点击生成音频按钮',
                  '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
+stream_mode_list = [('否', False), ('是', True)]

 def change_instruction(mode_checkbox_group):
     return instruct_dict[mode_checkbox_group]

-def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, speed_factor):
+def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor):
     if prompt_wav_upload is not None:
         prompt_wav = prompt_wav_upload
     elif prompt_wav_record is not None:
@@ -117,32 +112,25 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
     if mode_checkbox_group == '预训练音色':
         logging.info('get sft inference request')
         set_all_random_seed(seed)
-        output = cosyvoice.inference_sft(tts_text, sft_dropdown)
+        for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream):
+            yield (target_sr, i['tts_speech'].numpy().flatten())
     elif mode_checkbox_group == '3s极速复刻':
         logging.info('get zero_shot inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
-        output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
+        for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream):
+            yield (target_sr, i['tts_speech'].numpy().flatten())
     elif mode_checkbox_group == '跨语种复刻':
         logging.info('get cross_lingual inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
-        output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
+        for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream):
+            yield (target_sr, i['tts_speech'].numpy().flatten())
     else:
         logging.info('get instruct inference request')
         set_all_random_seed(seed)
-        output = cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text)
-
-    if speed_factor != 1.0:
-        try:
-            audio_data, sample_rate = speed_change(output["tts_speech"], target_sr, str(speed_factor))
-            audio_data = audio_data.numpy().flatten()
-        except Exception as e:
-            print(f"Failed to change speed of audio: \n{e}")
-    else:
-        audio_data = output['tts_speech'].numpy().flatten()
-
-    return (target_sr, audio_data)
+        for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream):
+            yield (target_sr, i['tts_speech'].numpy().flatten())

 def main():
     with gr.Blocks() as demo:
@@ -155,6 +143,7 @@ def main():
             mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
             instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
             sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25)
+            stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1])
             with gr.Column(scale=0.25):
                 seed_button = gr.Button(value="\U0001F3B2")
                 seed = gr.Number(value=0, label="随机推理种子")
@@ -167,11 +156,11 @@ def main():
         generate_button = gr.Button("生成音频")

-        audio_output = gr.Audio(label="合成音频")
+        audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=True)

         seed_button.click(generate_seed, inputs=[], outputs=seed)
         generate_button.click(generate_audio,
-                              inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, speed_factor],
+                              inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor],
                               outputs=[audio_output])
         mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
         demo.queue(max_size=4, default_concurrency_limit=2)
@@ -184,7 +173,7 @@ if __name__ == '__main__':
                         default=8000)
     parser.add_argument('--model_dir',
                         type=str,
-                        default='iic/CosyVoice-300M',
+                        default='pretrained_models/CosyVoice-300M',
                         help='local path or modelscope repo id')
     args = parser.parse_args()
     cosyvoice = CosyVoice(args.model_dir)