Merge branch 'main' into main

Authored by Xiang Lyu, 2025-12-31 10:37:18 +08:00, committed by GitHub
44 changed files with 2476 additions and 591 deletions

View File

@@ -23,7 +23,7 @@ import torch
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import logging
@@ -57,15 +57,9 @@ def main():
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
try:
model = CosyVoice(args.model_dir)
except Exception:
try:
model = CosyVoice2(args.model_dir)
except Exception:
raise TypeError('no valid model_type!')
model = AutoModel(model_dir=args.model_dir)
if not isinstance(model, CosyVoice2):
if model.__class__.__name__ == 'CosyVoice':
# 1. export llm text_encoder
llm_text_encoder = model.model.llm.text_encoder
script = get_optimized_script(llm_text_encoder)
@@ -89,14 +83,16 @@ def main():
script = get_optimized_script(flow_encoder.half())
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
logging.info('successfully exported flow_encoder')
else:
# 3. export flow encoder
elif model.__class__.__name__ == 'CosyVoice2':
# 1. export flow encoder
flow_encoder = model.model.flow.encoder
script = get_optimized_script(flow_encoder)
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
script = get_optimized_script(flow_encoder.half())
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
logging.info('successfully exported flow_encoder')
else:
raise ValueError('unsupported model type')
if __name__ == '__main__':

View File

@@ -27,7 +27,7 @@ from tqdm import tqdm
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import logging
@@ -58,13 +58,7 @@ def main():
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
try:
model = CosyVoice(args.model_dir)
except Exception:
try:
model = CosyVoice2(args.model_dir)
except Exception:
raise TypeError('no valid model_type!')
model = AutoModel(model_dir=args.model_dir)
# 1. export flow decoder estimator
estimator = model.model.flow.decoder.estimator

View File

@@ -1,126 +0,0 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import torch
from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.dataset.dataset import Dataset
def get_args():
parser = argparse.ArgumentParser(description='inference with your model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--prompt_data', required=True, help='prompt data file')
parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
parser.add_argument('--tts_text', required=True, help='tts input file')
parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
parser.add_argument('--llm_model', required=True, help='llm model file')
parser.add_argument('--flow_model', required=True, help='flow model file')
parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--mode',
default='sft',
choices=['sft', 'zero_shot'],
help='inference mode')
parser.add_argument('--result_dir', required=True, help='asr result file')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
# Init cosyvoice models from configs
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
try:
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
except Exception:
try:
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f)
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
except Exception:
raise TypeError('no valid model_type!')
model.load(args.llm_model, args.flow_model, args.hifigan_model)
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
sample_rate = configs['sample_rate']
del configs
os.makedirs(args.result_dir, exist_ok=True)
fn = os.path.join(args.result_dir, 'wav.scp')
f = open(fn, 'w')
with torch.no_grad():
for _, batch in tqdm(enumerate(test_data_loader)):
utts = batch["utts"]
assert len(utts) == 1, "inference mode only support batchsize 1"
text_token = batch["text_token"].to(device)
text_token_len = batch["text_token_len"].to(device)
tts_index = batch["tts_index"]
tts_text_token = batch["tts_text_token"].to(device)
tts_text_token_len = batch["tts_text_token_len"].to(device)
speech_token = batch["speech_token"].to(device)
speech_token_len = batch["speech_token_len"].to(device)
speech_feat = batch["speech_feat"].to(device)
speech_feat_len = batch["speech_feat_len"].to(device)
utt_embedding = batch["utt_embedding"].to(device)
spk_embedding = batch["spk_embedding"].to(device)
if args.mode == 'sft':
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
else:
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'prompt_text': text_token, 'prompt_text_len': text_token_len,
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
tts_speeches = []
for model_output in model.tts(**model_input):
tts_speeches.append(model_output['tts_speech'])
tts_speeches = torch.concat(tts_speeches, dim=1)
tts_key = '{}_{}'.format(utts[0], tts_index[0])
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
f.write('{} {}\n'.format(tts_key, tts_fn))
f.flush()
f.close()
logging.info('Result wav.scp saved in {}'.format(fn))
if __name__ == '__main__':
logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!')
main()

View File

@@ -19,7 +19,7 @@ from hyperpyyaml import load_hyperpyyaml
from modelscope import snapshot_download
import torch
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.class_utils import get_model_type
@@ -27,7 +27,6 @@ from cosyvoice.utils.class_utils import get_model_type
class CosyVoice:
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
self.instruct = True if '-Instruct' in model_dir else False
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
@@ -37,7 +36,7 @@ class CosyVoice:
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f)
assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
configs['feat_extractor'],
'{}/campplus.onnx'.format(model_dir),
@@ -67,9 +66,9 @@ class CosyVoice:
spks = list(self.frontend.spk2info.keys())
return spks
def add_zero_shot_spk(self, prompt_text, prompt_speech_16k, zero_shot_spk_id):
def add_zero_shot_spk(self, prompt_text, prompt_wav, zero_shot_spk_id):
assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_speech_16k, self.sample_rate, '')
model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_wav, self.sample_rate, '')
del model_input['text']
del model_input['text_len']
self.frontend.spk2info[zero_shot_spk_id] = model_input
@@ -89,12 +88,14 @@ class CosyVoice:
yield model_output
start_time = time.time()
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
if self.__class__.__name__ == 'CosyVoice3' and '<|endofprompt|>' not in prompt_text + tts_text:
logging.warning('<|endofprompt|> not found in CosyVoice3 inference, check your input text')
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
logging.warning('synthesis text {} is much shorter than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -103,9 +104,9 @@ class CosyVoice:
yield model_output
start_time = time.time()
def inference_cross_lingual(self, tts_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
def inference_cross_lingual(self, tts_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
model_input = self.frontend.frontend_cross_lingual(i, prompt_wav, self.sample_rate, zero_shot_spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -115,9 +116,7 @@ class CosyVoice:
start_time = time.time()
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
if self.instruct is False:
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
assert self.__class__.__name__ == 'CosyVoice', 'inference_instruct is only implemented for CosyVoice!'
instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
@@ -129,8 +128,8 @@ class CosyVoice:
yield model_output
start_time = time.time()
def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
def inference_vc(self, source_wav, prompt_wav, stream=False, speed=1.0):
model_input = self.frontend.frontend_vc(source_wav, prompt_wav, self.sample_rate)
start_time = time.time()
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
@@ -142,7 +141,6 @@ class CosyVoice:
class CosyVoice2(CosyVoice):
def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
self.instruct = True if '-Instruct' in model_dir else False
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
@@ -160,9 +158,9 @@ class CosyVoice2(CosyVoice):
'{}/spk2info.pt'.format(model_dir),
configs['allowed_special'])
self.sample_rate = configs['sample_rate']
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
load_jit, load_trt, fp16 = False, False, False
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or load_vllm is True or fp16 is True):
load_jit, load_trt, load_vllm, fp16 = False, False, False, False
logging.warning('no cuda device, set load_jit/load_trt/load_vllm/fp16 to False')
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
@@ -178,13 +176,9 @@ class CosyVoice2(CosyVoice):
self.fp16)
del configs
def inference_instruct(self, *args, **kwargs):
raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -192,3 +186,55 @@ class CosyVoice2(CosyVoice):
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
class CosyVoice3(CosyVoice2):
def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
model_dir = snapshot_download(model_dir)
hyper_yaml_path = '{}/cosyvoice3.yaml'.format(model_dir)
if not os.path.exists(hyper_yaml_path):
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
assert get_model_type(configs) == CosyVoice3Model, 'do not use {} for CosyVoice3 initialization!'.format(model_dir)
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
configs['feat_extractor'],
'{}/campplus.onnx'.format(model_dir),
'{}/speech_tokenizer_v3.onnx'.format(model_dir),
'{}/spk2info.pt'.format(model_dir),
configs['allowed_special'])
self.sample_rate = configs['sample_rate']
if torch.cuda.is_available() is False and (load_trt is True or fp16 is True):
load_trt, fp16 = False, False
logging.warning('no cuda device, set load_trt/fp16 to False')
self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
if load_vllm:
self.model.load_vllm('{}/vllm'.format(model_dir))
if load_trt:
if self.fp16 is True:
logging.warning('DiT TensorRT fp16 engine has some performance issues, use with caution!')
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
trt_concurrent,
self.fp16)
del configs
def AutoModel(**kwargs):
if not os.path.exists(kwargs['model_dir']):
kwargs['model_dir'] = snapshot_download(kwargs['model_dir'])
if os.path.exists('{}/cosyvoice.yaml'.format(kwargs['model_dir'])):
return CosyVoice(**kwargs)
elif os.path.exists('{}/cosyvoice2.yaml'.format(kwargs['model_dir'])):
return CosyVoice2(**kwargs)
elif os.path.exists('{}/cosyvoice3.yaml'.format(kwargs['model_dir'])):
return CosyVoice3(**kwargs)
else:
raise TypeError('No valid model type found!')
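
# A minimal sketch of the yaml-based dispatch in AutoModel above, with the
# CosyVoice classes stood in by plain strings so the routing can be exercised
# without downloading weights; the directory name below is hypothetical.
import os

def pick_model_class(model_dir):
    for yaml_name, cls_name in (('cosyvoice.yaml', 'CosyVoice'),
                                ('cosyvoice2.yaml', 'CosyVoice2'),
                                ('cosyvoice3.yaml', 'CosyVoice3')):
        if os.path.exists(os.path.join(model_dir, yaml_name)):
            return cls_name
    raise TypeError('No valid model type found!')

# e.g. a CosyVoice2 checkpoint directory would contain cosyvoice2.yaml:
# pick_model_class('pretrained_models/CosyVoice2-0.5B')  # -> 'CosyVoice2'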

View File

@@ -20,19 +20,10 @@ import numpy as np
import whisper
from typing import Callable
import torchaudio.compliance.kaldi as kaldi
import torchaudio
import os
import re
import inflect
try:
import ttsfrd
use_ttsfrd = True
except ImportError:
print("failed to import ttsfrd, use wetext instead")
from wetext import Normalizer as ZhNormalizer
from wetext import Normalizer as EnNormalizer
use_ttsfrd = False
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.file_utils import logging, load_wav
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
@@ -60,17 +51,29 @@ class CosyVoiceFrontEnd:
else:
self.spk2info = {}
self.allowed_special = allowed_special
self.use_ttsfrd = use_ttsfrd
if self.use_ttsfrd:
self.inflect_parser = inflect.engine()
# NOTE stay compatible when no text frontend tool is available
try:
import ttsfrd
self.frd = ttsfrd.TtsFrontendEngine()
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
'failed to initialize ttsfrd resource'
self.frd.set_lang_type('pinyinvg')
else:
self.zh_tn_model = ZhNormalizer(remove_erhua=False)
self.en_tn_model = EnNormalizer()
self.inflect_parser = inflect.engine()
self.text_frontend = 'ttsfrd'
logging.info('use ttsfrd frontend')
except Exception:
try:
from wetext import Normalizer as ZhNormalizer
from wetext import Normalizer as EnNormalizer
self.zh_tn_model = ZhNormalizer(remove_erhua=False)
self.en_tn_model = EnNormalizer()
self.text_frontend = 'wetext'
logging.info('use wetext frontend')
except Exception:
self.text_frontend = ''
logging.info('no text frontend is available')
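
# A standalone sketch of the graceful-degradation pattern above: try the
# preferred frontend, fall back to the next, and record which one is active so
# text_normalize can branch on it later; an empty string means no normalization.
def detect_text_frontend():
    try:
        import ttsfrd  # noqa: F401
        return 'ttsfrd'
    except Exception:
        pass
    try:
        from wetext import Normalizer  # noqa: F401
        return 'wetext'
    except Exception:
        return ''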
def _extract_text_token(self, text):
if isinstance(text, Generator):
@@ -89,7 +92,8 @@ class CosyVoiceFrontEnd:
for i in range(text_token.shape[1]):
yield text_token[:, i: i + 1]
def _extract_speech_token(self, speech):
def _extract_speech_token(self, prompt_wav):
speech = load_wav(prompt_wav, 16000)
assert speech.shape[1] / 16000 <= 30, 'extracting speech tokens from audio longer than 30s is not supported'
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
speech_token = self.speech_tokenizer_session.run(None,
@@ -101,7 +105,8 @@ class CosyVoiceFrontEnd:
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
return speech_token, speech_token_len
def _extract_spk_embedding(self, speech):
def _extract_spk_embedding(self, prompt_wav):
speech = load_wav(prompt_wav, 16000)
feat = kaldi.fbank(speech,
num_mel_bins=80,
dither=0,
@@ -112,7 +117,8 @@ class CosyVoiceFrontEnd:
embedding = torch.tensor([embedding]).to(self.device)
return embedding
def _extract_speech_feat(self, speech):
def _extract_speech_feat(self, prompt_wav):
speech = load_wav(prompt_wav, 24000)
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
@@ -122,15 +128,19 @@ class CosyVoiceFrontEnd:
if isinstance(text, Generator):
logging.info('get tts_text generator, will skip text_normalize!')
return [text]
# NOTE skip text_frontend when an SSML symbol appears in the text
if '<|' in text and '|>' in text:
text_frontend = False
if text_frontend is False or text == '':
return [text] if split is True else text
text = text.strip()
if self.use_ttsfrd:
if self.text_frontend == 'ttsfrd':
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
text = ''.join(texts)
else:
if contains_chinese(text):
text = self.zh_tn_model.normalize(text)
if self.text_frontend == 'wetext':
text = self.zh_tn_model.normalize(text)
text = text.replace("\n", "")
text = replace_blank(text)
text = replace_corner_mark(text)
@@ -141,7 +151,8 @@ class CosyVoiceFrontEnd:
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False))
else:
text = self.en_tn_model.normalize(text)
if self.text_frontend == 'wetext':
text = self.en_tn_model.normalize(text)
text = spell_out_number(text, self.inflect_parser)
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False))
@@ -154,32 +165,31 @@ class CosyVoiceFrontEnd:
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
return model_input
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
def frontend_zero_shot(self, tts_text, prompt_text, prompt_wav, resample_rate, zero_shot_spk_id):
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
if zero_shot_spk_id == '':
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav)
speech_token, speech_token_len = self._extract_speech_token(prompt_wav)
if resample_rate == 24000:
# cosyvoice2, force the speech_feat : speech_token length ratio to be exactly 2
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
embedding = self._extract_spk_embedding(prompt_speech_16k)
embedding = self._extract_spk_embedding(prompt_wav)
model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': embedding, 'flow_embedding': embedding}
else:
model_input = self.spk2info[zero_shot_spk_id]
model_input = {**self.spk2info[zero_shot_spk_id]}
model_input['text'] = tts_text_token
model_input['text_len'] = tts_text_token_len
return model_input
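
# A worked example of the 2:1 alignment above, assuming a hypothetical prompt
# that yields 101 mel frames and 49 speech tokens: both are trimmed so that
# feat length == 2 * token length exactly.
import torch

speech_feat = torch.zeros(1, 101, 80)                 # (batch, frames, mel bins)
speech_token = torch.zeros(1, 49, dtype=torch.int32)  # (batch, tokens)
token_len = min(speech_feat.shape[1] // 2, speech_token.shape[1])  # min(50, 49) = 49
speech_feat = speech_feat[:, :2 * token_len]          # 98 frames
speech_token = speech_token[:, :token_len]            # 49 tokens
assert speech_feat.shape[1] == 2 * speech_token.shape[1]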
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
def frontend_cross_lingual(self, tts_text, prompt_wav, resample_rate, zero_shot_spk_id):
model_input = self.frontend_zero_shot(tts_text, '', prompt_wav, resample_rate, zero_shot_spk_id)
# in cross lingual mode, we remove prompt in llm
del model_input['prompt_text']
del model_input['prompt_text_len']
@@ -191,22 +201,21 @@ class CosyVoiceFrontEnd:
model_input = self.frontend_sft(tts_text, spk_id)
# in instruct mode, we remove spk_embedding in llm due to information leakage
del model_input['llm_embedding']
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text)
model_input['prompt_text'] = instruct_text_token
model_input['prompt_text_len'] = instruct_text_token_len
return model_input
def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
def frontend_instruct2(self, tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id):
model_input = self.frontend_zero_shot(tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id)
del model_input['llm_prompt_speech_token']
del model_input['llm_prompt_speech_token_len']
return model_input
def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
embedding = self._extract_spk_embedding(prompt_speech_16k)
def frontend_vc(self, source_speech_16k, prompt_wav, resample_rate):
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_wav)
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_wav)
embedding = self._extract_spk_embedding(prompt_wav)
source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,

View File

@@ -38,9 +38,6 @@ class CosyVoiceModel:
self.flow = flow
self.hift = hift
self.fp16 = fp16
if self.fp16 is True:
self.llm.half()
self.flow.half()
self.token_min_hop_len = 2 * self.flow.input_frame_rate
self.token_max_hop_len = 4 * self.flow.input_frame_rate
self.token_overlap_len = 20
@@ -63,6 +60,7 @@ class CosyVoiceModel:
self.mel_overlap_dict = {}
self.flow_cache_dict = {}
self.hift_cache_dict = {}
self.silent_tokens = []
def load(self, llm_model, flow_model, hift_model):
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
@@ -101,26 +99,33 @@ class CosyVoiceModel:
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
cur_silent_token_num, max_silent_token_num = 0, 5
with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
if isinstance(text, Generator):
assert isinstance(self, CosyVoice2Model) and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2 and do not support vllm!'
for i in self.llm.inference_bistream(text=text,
assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and does not support vllm!'
token_generator = self.llm.inference_bistream(text=text,
prompt_text=prompt_text.to(self.device),
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device))
else:
token_generator = self.llm.inference(text=text.to(self.device),
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
prompt_text=prompt_text.to(self.device),
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device)):
self.tts_speech_token_dict[uuid].append(i)
else:
for i in self.llm.inference(text=text.to(self.device),
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
prompt_text=prompt_text.to(self.device),
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device),
uuid=uuid):
self.tts_speech_token_dict[uuid].append(i)
embedding=llm_embedding.to(self.device),
uuid=uuid)
for i in token_generator:
if i in self.silent_tokens:
cur_silent_token_num += 1
if cur_silent_token_num > max_silent_token_num:
continue
else:
cur_silent_token_num = 0
self.tts_speech_token_dict[uuid].append(i)
self.llm_end_dict[uuid] = True
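
# A standalone sketch of the silent-token suppression in llm_job above: after
# max_silent consecutive tokens from the silent set, further silent tokens are
# dropped until a non-silent token resets the counter; token values are made up.
def drop_long_silence(tokens, silent_tokens, max_silent=5):
    cur = 0
    for t in tokens:
        if t in silent_tokens:
            cur += 1
            if cur > max_silent:
                continue  # suppress the tail of an over-long silence
        else:
            cur = 0
        yield t

# list(drop_long_silence([7, 1, 1, 1, 1, 1, 1, 1, 9], {1}))
# -> [7, 1, 1, 1, 1, 1, 9]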
def vc_job(self, source_speech_token, uuid):
@@ -129,7 +134,7 @@ class CosyVoiceModel:
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
@@ -249,9 +254,6 @@ class CosyVoice2Model(CosyVoiceModel):
self.flow = flow
self.hift = hift
self.fp16 = fp16
if self.fp16 is True:
self.llm.half()
self.flow.half()
# NOTE must match the training static_chunk_size
self.token_hop_len = 25
# hift cache
@@ -266,6 +268,7 @@ class CosyVoice2Model(CosyVoiceModel):
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.hift_cache_dict = {}
self.silent_tokens = []
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
@@ -284,7 +287,7 @@ class CosyVoice2Model(CosyVoiceModel):
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, _ = self.flow.inference(token=token.to(self.device),
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
@@ -384,3 +387,55 @@ class CosyVoice2Model(CosyVoiceModel):
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.current_stream().synchronize()
class CosyVoice3Model(CosyVoice2Model):
def __init__(self,
llm: torch.nn.Module,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.llm = llm
self.flow = flow
self.hift = hift
self.fp16 = fp16
# NOTE must match the training static_chunk_size
self.token_hop_len = 25
# rtf and decoding related
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.lock = threading.Lock()
# dict used to store session related variable
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.hift_cache_dict = {}
# FSQ silent and breath token
self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
streaming=stream,
finalize=finalize)
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append mel cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel = self.hift_cache_dict[uuid]['mel']
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
self.hift_cache_dict[uuid]['mel'] = tts_mel
else:
self.hift_cache_dict[uuid] = {'mel': tts_mel, 'speech_offset': 0}
if speed != 1.0:
assert token_offset == 0 and finalize is True, 'speed change only supports non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize)
tts_speech = tts_speech[:, self.hift_cache_dict[uuid]['speech_offset']:]
self.hift_cache_dict[uuid]['speech_offset'] += tts_speech.shape[1]
return tts_speech
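
# A sketch of the offset bookkeeping above: the vocoder is re-run on the growing
# mel cache and only the samples past speech_offset (already emitted) are
# returned; the 256 samples-per-frame "vocoder" below is a stand-in, not hift.
import torch

def fake_vocoder(mel):
    return torch.zeros(1, mel.shape[2] * 256)

cache = {'mel': torch.zeros(1, 80, 0), 'speech_offset': 0}
for new_frames in (10, 10, 5):                        # three streaming chunks
    cache['mel'] = torch.concat([cache['mel'], torch.zeros(1, 80, new_frames)], dim=2)
    speech = fake_vocoder(cache['mel'])
    chunk = speech[:, cache['speech_offset']:]        # only the not-yet-emitted tail
    cache['speech_offset'] += chunk.shape[1]          # 2560, then 5120, then 6400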

View File

@@ -145,7 +145,11 @@ def Dataset(data_list_file,
shuffle=shuffle,
partition=partition)
# map partial arg to padding func
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo)
for i in range(1, len(data_pipeline)):
if data_pipeline[i].func.__name__ == 'compute_fbank':
data_pipeline[i] = partial(data_pipeline[i], token_mel_ratio=0)
if data_pipeline[i].func.__name__ == 'padding':
data_pipeline[i] = partial(data_pipeline[i], gan=gan, dpo=dpo)
for func in data_pipeline:
dataset = Processor(dataset, func, mode=mode)
return dataset
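
# A small sketch of the rebinding above: pipeline stages are functools.partial
# objects, and .func.__name__ identifies the wrapped function so extra keyword
# arguments can be attached to specific stages; the stage functions are dummies.
from functools import partial

def compute_fbank(data, token_mel_ratio=2):
    return ('fbank', token_mel_ratio)

def padding(data, gan=False, dpo=False):
    return ('padding', gan, dpo)

data_pipeline = [partial(compute_fbank), partial(padding)]
for i in range(len(data_pipeline)):
    if data_pipeline[i].func.__name__ == 'compute_fbank':
        data_pipeline[i] = partial(data_pipeline[i], token_mel_ratio=0)
    if data_pipeline[i].func.__name__ == 'padding':
        data_pipeline[i] = partial(data_pipeline[i], gan=True, dpo=False)
# data_pipeline[0](None) -> ('fbank', 0); data_pipeline[1](None) -> ('padding', True, False)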

View File

@@ -26,7 +26,7 @@ import pyworld as pw
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
def parquet_opener(data, mode='train', tts_data={}):
def parquet_opener(data, mode='train'):
""" Give url or local file, return file descriptor
Inplace operation.
@@ -44,12 +44,8 @@ def parquet_opener(data, mode='train', tts_data={}):
df = df.to_pandas()
for i in range(len(df)):
sample.update(dict(df.loc[i]))
if mode == 'train':
# NOTE do not return sample directly, must initialize a new dict
yield {**sample}
else:
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
yield {**sample, 'tts_index': index, 'tts_text': text}
# NOTE do not return sample directly, must initialize a new dict
yield {**sample}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
@@ -242,6 +238,10 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
for sample in data:
assert 'text' in sample
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
if 'instruct' in sample:
sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
else:
sample['instruct_token'] = tokenizer.encode('', allowed_special=allowed_special)
yield sample
@@ -390,6 +390,9 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
instruct_token_len = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0)
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
batch = {
@@ -403,6 +406,8 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
"text": text,
"text_token": text_token,
"text_token_len": text_token_len,
"instruct_token": instruct_token,
"instruct_token_len": instruct_token_len,
"utt_embedding": utt_embedding,
"spk_embedding": spk_embedding,
}
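
# A sketch of how the new instruct_token field is batched above: variable-length
# token lists are length-recorded and right-padded with 0, exactly like
# text_token; the token ids are made up.
import torch
from torch.nn.utils.rnn import pad_sequence

instruct_token = [torch.tensor([5, 6, 7]), torch.tensor([8])]
instruct_token_len = torch.tensor([t.size(0) for t in instruct_token], dtype=torch.int32)
instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0)
# instruct_token     -> tensor([[5, 6, 7], [8, 0, 0]])
# instruct_token_len -> tensor([3, 1], dtype=torch.int32)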

cosyvoice/flow/DiT/dit.py (new file, 176 lines)
View File

@@ -0,0 +1,176 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from x_transformers.x_transformers import RotaryEmbedding
from cosyvoice.utils.mask import add_optional_chunk_mask
from cosyvoice.flow.DiT.modules import (
TimestepEmbedding,
ConvNeXtV2Block,
CausalConvPositionEmbedding,
DiTBlock,
AdaLayerNormZero_Final,
precompute_freqs_cis,
get_pos_embed_indices,
)
# Text embedding
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
if conv_layers > 0:
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
)
else:
self.extra_modeling = False
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
batch, text_len = text.shape[0], text.shape[1]
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text = F.pad(text, (0, seq_len - text_len), value=0)
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text) # b n -> b n d
# possible extra modeling
if self.extra_modeling:
# sinus pos emb
batch_start = torch.zeros((batch,), dtype=torch.long)
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
text_pos_embed = self.freqs_cis[pos_idx]
text = text + text_pos_embed
# convnextv2 blocks
text = self.text_blocks(text)
return text
# noised input audio and context mixing embedding
class InputEmbedding(nn.Module):
def __init__(self, mel_dim, text_dim, out_dim, spk_dim=None):
super().__init__()
spk_dim = 0 if spk_dim is None else spk_dim
self.spk_dim = spk_dim
self.proj = nn.Linear(mel_dim * 2 + text_dim + spk_dim, out_dim)
self.conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim)
def forward(
self,
x: float["b n d"],
cond: float["b n d"],
text_embed: float["b n d"],
spks: float["b d"],
):
to_cat = [x, cond, text_embed]
if self.spk_dim > 0:
spks = repeat(spks, "b c -> b t c", t=x.shape[1])
to_cat.append(spks)
x = self.proj(torch.cat(to_cat, dim=-1))
x = self.conv_pos_embed(x) + x
return x
# Transformer backbone using DiT blocks
class DiT(nn.Module):
def __init__(
self,
*,
dim,
depth=8,
heads=8,
dim_head=64,
dropout=0.1,
ff_mult=4,
mel_dim=80,
mu_dim=None,
long_skip_connection=False,
spk_dim=None,
out_channels=None,
static_chunk_size=50,
num_decoding_left_chunks=2
):
super().__init__()
self.time_embed = TimestepEmbedding(dim)
if mu_dim is None:
mu_dim = mel_dim
self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)
self.rotary_embed = RotaryEmbedding(dim_head)
self.dim = dim
self.depth = depth
self.transformer_blocks = nn.ModuleList(
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
)
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)
self.out_channels = out_channels
self.static_chunk_size = static_chunk_size
self.num_decoding_left_chunks = num_decoding_left_chunks
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
x = x.transpose(1, 2)
mu = mu.transpose(1, 2)
cond = cond.transpose(1, 2)
spks = spks.unsqueeze(dim=1)
batch, seq_len = x.shape[0], x.shape[1]
if t.ndim == 0:
t = t.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(t)
x = self.input_embed(x, cond, mu, spks.squeeze(1))
rope = self.rotary_embed.forward_from_seq_len(seq_len)
if self.long_skip_connection is not None:
residual = x
if streaming is True:
attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)
else:
attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1).unsqueeze(dim=1)
for block in self.transformer_blocks:
x = block(x, t, mask=attn_mask.bool(), rope=rope)
if self.long_skip_connection is not None:
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
x = self.norm_out(x, t)
output = self.proj_out(x).transpose(1, 2)
return output

View File

@@ -0,0 +1,616 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
from typing import Optional
import math
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
from x_transformers.x_transformers import apply_rotary_pos_emb
# raw wav to mel spec
class MelSpec(nn.Module):
def __init__(
self,
filter_length=1024,
hop_length=256,
win_length=1024,
n_mel_channels=100,
target_sample_rate=24_000,
normalize=False,
power=1,
norm=None,
center=True,
):
super().__init__()
self.n_mel_channels = n_mel_channels
self.mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=target_sample_rate,
n_fft=filter_length,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mel_channels,
power=power,
center=center,
normalized=normalize,
norm=norm,
)
self.register_buffer("dummy", torch.tensor(0), persistent=False)
def forward(self, inp):
if len(inp.shape) == 3:
inp = inp.squeeze(1) # 'b 1 nw -> b nw'
assert len(inp.shape) == 2
if self.dummy.device != inp.device:
self.to(inp.device)
mel = self.mel_stft(inp)
mel = mel.clamp(min=1e-5).log()
return mel
# sinusoidal position embedding
class SinusPositionEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x, scale=1000):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
# convolutional position embedding
class ConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.conv1d = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
x = x.permute(0, 2, 1)
x = self.conv1d(x)
out = x.permute(0, 2, 1)
if mask is not None:
out = out.masked_fill(~mask, 0.0)
return out
class CausalConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.kernel_size = kernel_size
self.conv1 = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
nn.Mish(),
)
self.conv2 = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
x = x.permute(0, 2, 1)
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
x = self.conv1(x)
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
x = self.conv2(x)
out = x.permute(0, 2, 1)
if mask is not None:
out = out.masked_fill(~mask, 0.0)
return out
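
# A quick check of the causal padding trick above: left-padding by
# kernel_size - 1 keeps the sequence length and makes each output frame depend
# only on current and past frames; sizes below are arbitrary.
import torch
import torch.nn.functional as F

kernel_size = 5
conv = torch.nn.Conv1d(8, 8, kernel_size, padding=0)
x = torch.randn(1, 8, 20)                              # (batch, channels, time)
y = conv(F.pad(x, (kernel_size - 1, 0)))               # pad on the left only
assert y.shape[-1] == x.shape[-1]
x2 = x.clone()
x2[:, :, 10] += 1.0                                    # perturb a "future" frame
y2 = conv(F.pad(x2, (kernel_size - 1, 0)))
assert torch.allclose(y[:, :, :10], y2[:, :, :10])     # earlier outputs unchanged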
# rotary positional embedding related
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cos = torch.cos(freqs) # real part
freqs_sin = torch.sin(freqs) # imaginary part
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
# length = length if isinstance(length, int) else length.max()
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
pos = (
start.unsqueeze(1)
+ (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
)
# clamp to avoid indexing beyond the precomputed table for extra long sequences
pos = torch.where(pos < max_pos, pos, max_pos - 1)
return pos
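
# A usage sketch for the two helpers above: precompute a cos/sin table once,
# then gather per-token positions for a batch whose samples may start at
# different offsets; the sizes are arbitrary.
import torch

dim, max_pos = 64, 4096
freqs_cis = precompute_freqs_cis(dim, max_pos)            # (4096, 64): [cos | sin]
start = torch.tensor([0, 100])                            # per-sample start offsets
pos = get_pos_embed_indices(start, 32, max_pos=max_pos)   # (2, 32), clamped to the table
text_pos_embed = freqs_cis[pos]                           # (2, 32, 64)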
# Global Response Normalization layer (Instance Normalization ?)
class GRN(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
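
# A quick property check for GRN above: gamma and beta start at zero, so at
# initialization the block is an identity mapping and only learns a per-channel
# rescaling as gamma/beta move away from zero.
import torch

grn = GRN(dim=16)
x = torch.randn(2, 10, 16)                 # (batch, seq, dim)
assert torch.allclose(grn(x), x)           # identity at init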
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
class ConvNeXtV2Block(nn.Module):
def __init__(
self,
dim: int,
intermediate_dim: int,
dilation: int = 1,
):
super().__init__()
padding = (dilation * (7 - 1)) // 2
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(intermediate_dim)
self.pwconv2 = nn.Linear(intermediate_dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = x.transpose(1, 2) # b n d -> b d n
x = self.dwconv(x)
x = x.transpose(1, 2) # b d n -> b n d
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
return residual + x
# AdaLayerNormZero
# return with modulated x for attn input, and params for later mlp modulation
class AdaLayerNormZero(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 6)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb=None):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
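
# A usage sketch for AdaLayerNormZero above: one conditioning vector per sample
# is projected to six modulation tensors (shift/scale/gate for attention, then
# shift/scale/gate for the MLP); shapes below are arbitrary.
import torch

norm = AdaLayerNormZero(dim=32)
x = torch.randn(2, 10, 32)                 # (batch, seq, dim)
t_emb = torch.randn(2, 32)                 # time/conditioning embedding
x_mod, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm(x, emb=t_emb)
# x_mod: (2, 10, 32); the gates/shifts/scales: (2, 32) each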
# AdaLayerNormZero for final layer
# returns only the modulated x for the attention input, since there is no further mlp modulation
class AdaLayerNormZero_Final(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 2)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(self.silu(emb))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
# FeedForward
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
activation = nn.GELU(approximate=approximate)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.ff(x)
# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py
class Attention(nn.Module):
def __init__(
self,
processor: JointAttnProcessor | AttnProcessor,
dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.processor = processor
self.dim = dim
self.heads = heads
self.inner_dim = dim_head * heads
self.dropout = dropout
self.context_dim = context_dim
self.context_pre_only = context_pre_only
self.to_q = nn.Linear(dim, self.inner_dim)
self.to_k = nn.Linear(dim, self.inner_dim)
self.to_v = nn.Linear(dim, self.inner_dim)
if self.context_dim is not None:
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
if self.context_pre_only is not None:
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, dim))
self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_out_c = nn.Linear(self.inner_dim, dim)
def forward(
self,
x: float["b n d"], # noised input x # noqa: F722
c: float["b n d"] = None, # context c # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
else:
return self.processor(self, x, mask=mask, rope=rope)
# Attention processor
class AttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding
) -> torch.FloatTensor:
batch_size = x.shape[0]
# `sample` projections.
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = mask
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if mask is not None:
if mask.dim() == 2:
mask = mask.unsqueeze(-1)
else:
mask = mask[:, 0, -1].unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
return x
# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py
class JointAttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
c: float["b nt d"] = None, # context c, here text # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.FloatTensor:
residual = x
batch_size = c.shape[0]
# `sample` projections.
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# `context` projections.
c_query = attn.to_q_c(c)
c_key = attn.to_k_c(c)
c_value = attn.to_v_c(c)
# apply rope for context and noised input independently
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
if c_rope is not None:
freqs, xpos_scale = c_rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
# attention
query = torch.cat([query, c_query], dim=1)
key = torch.cat([key, c_key], dim=1)
value = torch.cat([value, c_value], dim=1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# Split the attention outputs.
x, c = (
x[:, : residual.shape[1]],
x[:, residual.shape[1]:],
)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if not attn.context_pre_only:
c = attn.to_out_c(c)
if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
return x, c
# DiT Block
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor=AttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
# attention
attn_output = self.attn(x=norm, mask=mask, rope=rope)
# process attention output for input x
x = x + gate_msa.unsqueeze(1) * attn_output
ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(ff_norm)
x = x + gate_mlp.unsqueeze(1) * ff_output
return x
# MMDiT Block https://arxiv.org/abs/2403.03206
class MMDiTBlock(nn.Module):
r"""
modified from diffusers/src/diffusers/models/attention.py
notes.
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
_x: noised input related. (right part)
context_pre_only: the last layer only does prenorm + modulation, since there is no further ffn
"""
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
super().__init__()
self.context_pre_only = context_pre_only
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
self.attn_norm_x = AdaLayerNormZero(dim)
self.attn = Attention(
processor=JointAttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
context_dim=dim,
context_pre_only=context_pre_only,
)
if not context_pre_only:
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
else:
self.ff_norm_c = None
self.ff_c = None
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
# pre-norm & modulation for attention input
if self.context_pre_only:
norm_c = self.attn_norm_c(c, t)
else:
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
# attention
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
# process attention output for context c
if self.context_pre_only:
c = None
else: # if not last layer
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
c_ff_output = self.ff_c(norm_c)
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
# process attention output for input x
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
x_ff_output = self.ff_x(norm_x)
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
return c, x
# time step conditioning embedding
class TimestepEmbedding(nn.Module):
def __init__(self, dim, freq_embed_dim=256):
super().__init__()
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
def forward(self, timestep: float["b"]): # noqa: F821
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d
return time
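# TimestepEmbedding above maps scalar diffusion steps to (b, dim) vectors via a sinusoidal embedding
# followed by a small MLP. Self-contained sketch of the sinusoidal part (the repo's
# SinusPositionEmbedding lives elsewhere in this file; the scale/base values here are assumptions):
import math
import torch

def sinus_embedding(t: torch.Tensor, dim: int = 256, scale: float = 1000.0) -> torch.Tensor:
    # t: (b,) timesteps in [0, 1]; returns (b, dim) concatenated sin/cos features
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = scale * t[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

print(sinus_embedding(torch.rand(4)).shape)  # torch.Size([4, 256])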

View File

@@ -37,14 +37,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.mel_feat_conf = mel_feat_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
@@ -165,14 +162,11 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.mel_feat_conf = mel_feat_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
@@ -279,3 +273,160 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat.float(), None
class CausalMaskedDiffWithDiT(torch.nn.Module):
def __init__(self,
input_size: int = 512,
output_size: int = 80,
spk_embed_dim: int = 192,
output_type: str = "mel",
vocab_size: int = 4096,
input_frame_rate: int = 50,
only_mask_loss: bool = True,
token_mel_ratio: int = 2,
pre_lookahead_len: int = 3,
pre_lookahead_layer: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
logging.info(f"input frame rate={self.input_frame_rate}")
self.input_embedding = nn.Embedding(vocab_size, input_size)
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
self.pre_lookahead_len = pre_lookahead_len
self.pre_lookahead_layer = pre_lookahead_layer
self.decoder = decoder
self.only_mask_loss = only_mask_loss
self.token_mel_ratio = token_mel_ratio
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)
# NOTE unified streaming/non-streaming training, i.e. static_chunk_size > 0 or static_chunk_size == 0
streaming = True if random.random() < 0.5 else False
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# mask and embed speech token
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# speech token encode
h = self.pre_lookahead_layer(token)
h = h.repeat_interleave(self.token_mel_ratio, dim=1)
mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.3 * j))
conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds,
streaming=streaming,
)
return {'loss': loss}
@torch.inference_mode()
def inference(self,
token,
token_len,
prompt_token,
prompt_token_len,
prompt_feat,
prompt_feat_len,
embedding,
streaming,
finalize):
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat speech token and prompt speech token
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# speech token encode
if finalize is True:
h = self.pre_lookahead_layer(token)
else:
h = self.pre_lookahead_layer(token[:, :-self.pre_lookahead_len], context=token[:, -self.pre_lookahead_len:])
h = h.repeat_interleave(self.token_mel_ratio, dim=1)
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
# get conditions
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat, _ = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10,
streaming=streaming
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat.float(), None
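# Length bookkeeping sketch for the inference above (hypothetical lengths): each speech token is
# upsampled to token_mel_ratio mel frames, the prompt mel is used as condition, and only the newly
# generated frames are returned.
import torch

token_mel_ratio = 2
prompt_token_len, token_len = 25, 100
h = torch.rand(1, prompt_token_len + token_len, 512)      # encoded prompt + target tokens
h = h.repeat_interleave(token_mel_ratio, dim=1)           # (1, 250, 512): 2 mel frames per token
mel_len1 = prompt_token_len * token_mel_ratio             # 50, i.e. prompt_feat.shape[1] in the code
mel_len2 = h.shape[1] - mel_len1                          # 200 newly generated mel frames
feat = torch.rand(1, 80, mel_len1 + mel_len2)             # stand-in for the decoder output
feat = feat[:, :, mel_len1:]                              # drop the prompt part
print(mel_len1, mel_len2, feat.shape)                     # 50 200 torch.Size([1, 80, 200])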
if __name__ == '__main__':
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from hyperpyyaml import load_hyperpyyaml
with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
configs = load_hyperpyyaml(f, overrides={'llm': None, 'hift': None})
model = configs['flow']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()
max_len = 10 * model.decoder.estimator.static_chunk_size
chunk_size = model.decoder.estimator.static_chunk_size
context_size = model.pre_lookahead_layer.pre_lookahead_len
token = torch.randint(0, 6561, size=(1, max_len)).to(device)
token_len = torch.tensor([max_len]).to(device)
prompt_token = torch.randint(0, 6561, size=(1, chunk_size)).to(device)
prompt_token_len = torch.tensor([chunk_size]).to(device)
prompt_feat = torch.rand(1, chunk_size * 2, 80).to(device)
prompt_feat_len = torch.tensor([chunk_size * 2]).to(device)
prompt_embedding = torch.rand(1, 192).to(device)
pred_gt, _ = model.inference(token, token_len, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=True)
for i in range(0, max_len, chunk_size):
finalize = True if i + chunk_size + context_size >= max_len else False
pred_chunk, _ = model.inference(token[:, :i + chunk_size + context_size], torch.tensor([token[:, :i + chunk_size + context_size].shape[1]]).to(device),
prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=finalize)
pred_chunk = pred_chunk[:, :, i * model.token_mel_ratio:]
print((pred_gt[:, :, i * model.token_mel_ratio: i * model.token_mel_ratio + pred_chunk.shape[2]] - pred_chunk).abs().max().item())

View File

@@ -91,12 +91,13 @@ class ConditionalCFM(BASECFM):
sol = []
# Do not use concat, it may change the memory format and cause trt to produce wrong results!
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
# NOTE when flow runs in amp mode, x.dtype is float32, which causes nan in trt fp16 inference, so set dtype=spks.dtype
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
for step in range(1, len(t_span)):
# Classifier-Free Guidance inference introduced in VoiceBox
x_in[:] = x

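# The 2x-batched buffers above stack the conditional and unconditional branches of classifier-free
# guidance into a single estimator call; a minimal sketch of how the two outputs are typically
# blended afterwards (the cfg_rate value is illustrative, not necessarily the repo's exact code):
import torch

def cfg_blend(cond_out: torch.Tensor, uncond_out: torch.Tensor, cfg_rate: float = 0.7) -> torch.Tensor:
    # push the prediction away from the unconditional branch
    return (1.0 + cfg_rate) * cond_out - cfg_rate * uncond_out

dphi = torch.rand(2, 80, 120)                  # estimator output for the stacked batch [cond; uncond]
print(cfg_blend(dphi[:1], dphi[1:]).shape)     # torch.Size([1, 80, 120])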
View File

@@ -17,6 +17,7 @@ try:
from torch.nn.utils.parametrizations import weight_norm
except ImportError:
from torch.nn.utils import weight_norm
from cosyvoice.transformer.convolution import CausalConv1d
class ConvRNNF0Predictor(nn.Module):
@@ -56,3 +57,47 @@ class ConvRNNF0Predictor(nn.Module):
x = self.condnet(x)
x = x.transpose(1, 2)
return torch.abs(self.classifier(x).squeeze(-1))
class CausalConvRNNF0Predictor(nn.Module):
def __init__(self,
num_class: int = 1,
in_channels: int = 80,
cond_channels: int = 512
):
super().__init__()
self.num_class = num_class
self.condnet = nn.Sequential(
weight_norm(
CausalConv1d(in_channels, cond_channels, kernel_size=4, causal_type='right')
),
nn.ELU(),
weight_norm(
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
),
nn.ELU(),
weight_norm(
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
),
nn.ELU(),
weight_norm(
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
),
nn.ELU(),
weight_norm(
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
),
nn.ELU(),
)
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
def forward(self, x: torch.Tensor, finalize: bool = True) -> torch.Tensor:
if finalize is True:
x = self.condnet[0](x)
else:
x = self.condnet[0](x[:, :, :-self.condnet[0].causal_padding], x[:, :, -self.condnet[0].causal_padding:])
for i in range(1, len(self.condnet)):
x = self.condnet[i](x)
x = x.transpose(1, 2)
return torch.abs(self.classifier(x).squeeze(-1))

View File

@@ -28,7 +28,7 @@ try:
except ImportError:
from torch.nn.utils import weight_norm
from torch.distributions.uniform import Uniform
from cosyvoice.transformer.convolution import CausalConv1d, CausalConv1dDownSample, CausalConv1dUpsample
from cosyvoice.transformer.activation import Snake
from cosyvoice.utils.common import get_padding
from cosyvoice.utils.common import init_weights
@@ -50,8 +50,10 @@ class ResBlock(torch.nn.Module):
channels: int = 512,
kernel_size: int = 3,
dilations: List[int] = [1, 3, 5],
causal: bool = False,
):
super(ResBlock, self).__init__()
self.causal = causal
self.convs1 = nn.ModuleList()
self.convs2 = nn.ModuleList()
@@ -64,7 +66,14 @@ class ResBlock(torch.nn.Module):
kernel_size,
1,
dilation=dilation,
padding=get_padding(kernel_size, dilation)
padding=get_padding(kernel_size, dilation)) if causal is False else
CausalConv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation,
causal_type='left'
)
)
)
@@ -76,7 +85,14 @@ class ResBlock(torch.nn.Module):
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)
padding=get_padding(kernel_size, 1)) if causal is False else
CausalConv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
causal_type='left'
)
)
)
@@ -139,11 +155,13 @@ class SineGen(torch.nn.Module):
@torch.no_grad()
def forward(self, f0):
""" sine_tensor, uv = forward(f0)
input F0: tensor(batchsize=1, dim=1, length)
f0 for unvoiced steps should be 0
output sine_tensor: tensor(batchsize=1, length, dim)
output uv: tensor(batchsize=1, length, 1)
"""
:param f0: [B, 1, sample_len], Hz
:return: [B, 1, sample_len]
"""
f0 = f0.transpose(1, 2)
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
for i in range(self.harmonic_num + 1):
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
@@ -168,59 +186,7 @@ class SineGen(torch.nn.Module):
# first: set the unvoiced part to 0 by uv
# then: additive noise
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
class SourceModuleHnNSF(torch.nn.Module):
""" SourceModule for hn-nsf
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0)
sampling_rate: sampling_rate in Hz
harmonic_num: number of harmonic above F0 (default: 0)
sine_amp: amplitude of sine source signal (default: 0.1)
add_noise_std: std of additive Gaussian noise (default: 0.003)
note that amplitude of noise in unvoiced is decided
by sine_amp
voiced_threshold: threshold to set U/V given F0 (default: 0)
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length 1)
uv (batchsize, length, 1)
"""
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0):
super(SourceModuleHnNSF, self).__init__()
self.sine_amp = sine_amp
self.noise_std = add_noise_std
# to produce sine waveforms
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
sine_amp, add_noise_std, voiced_threshod)
# to merge source harmonics into a single excitation
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
def forward(self, x):
"""
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length 1)
"""
# source for harmonic branch
with torch.no_grad():
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
sine_wavs = sine_wavs.transpose(1, 2)
uv = uv.transpose(1, 2)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
# source for noise branch, in the same shape as uv
noise = torch.randn_like(uv) * self.sine_amp / 3
return sine_merge, noise, uv
return sine_waves.transpose(1, 2), uv.transpose(1, 2), noise
class SineGen2(torch.nn.Module):
@@ -242,7 +208,8 @@ class SineGen2(torch.nn.Module):
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
sine_amp=0.1, noise_std=0.003,
voiced_threshold=0,
flag_for_pulse=False):
flag_for_pulse=False,
causal=False):
super(SineGen2, self).__init__()
self.sine_amp = sine_amp
self.noise_std = noise_std
@@ -252,6 +219,11 @@ class SineGen2(torch.nn.Module):
self.voiced_threshold = voiced_threshold
self.flag_for_pulse = flag_for_pulse
self.upsample_scale = upsample_scale
self.causal = causal
if causal is True:
self.rand_ini = torch.rand(1, 9)
self.rand_ini[:, 0] = 0
self.sine_waves = torch.rand(1, 300 * 24000, 9)
def _f02uv(self, f0):
# generate uv signal
@@ -267,9 +239,12 @@ class SineGen2(torch.nn.Module):
rad_values = (f0_values / self.sampling_rate) % 1
# initial phase noise (no noise for fundamental component)
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
if self.training is False and self.causal is True:
rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini.to(rad_values.device)
else:
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
# instantaneous phase: sine[t] = sin(2*pi * sum_{i=1}^{t} rad[i])
if not self.flag_for_pulse:
@@ -279,7 +254,7 @@ class SineGen2(torch.nn.Module):
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
scale_factor=self.upsample_scale, mode="nearest" if self.causal is True else 'linear').transpose(1, 2)
sines = torch.sin(phase)
else:
# If necessary, make sure that the first time step of every
@@ -331,7 +306,10 @@ class SineGen2(torch.nn.Module):
# std = self.sine_amp/3 -> max value ~ self.sine_amp
# . for voiced regions is self.noise_std
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
if self.training is False and self.causal is True:
noise = noise_amp * self.sine_waves[:, :sine_waves.shape[1]].to(sine_waves.device)
else:
noise = noise_amp * torch.randn_like(sine_waves)
# first: set the unvoiced part to 0 by uv
# then: additive noise
@@ -339,7 +317,7 @@ class SineGen2(torch.nn.Module):
return sine_waves, uv, noise
class SourceModuleHnNSF2(torch.nn.Module):
class SourceModuleHnNSF(torch.nn.Module):
""" SourceModule for hn-nsf
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0)
@@ -358,19 +336,24 @@ class SourceModuleHnNSF2(torch.nn.Module):
"""
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0):
super(SourceModuleHnNSF2, self).__init__()
add_noise_std=0.003, voiced_threshod=0, sinegen_type='1', causal=False):
super(SourceModuleHnNSF, self).__init__()
self.sine_amp = sine_amp
self.noise_std = add_noise_std
# to produce sine waveforms
self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
sine_amp, add_noise_std, voiced_threshod)
if sinegen_type == '1':
self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
else:
self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, sine_amp, add_noise_std, voiced_threshod, causal=causal)
# to merge source harmonics into a single excitation
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
self.causal = causal
if causal is True:
self.uv = torch.rand(1, 300 * 24000, 1)
def forward(self, x):
"""
@@ -385,7 +368,10 @@ class SourceModuleHnNSF2(torch.nn.Module):
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
# source for noise branch, in the same shape as uv
noise = torch.randn_like(uv) * self.sine_amp / 3
if self.training is False and self.causal is True:
noise = self.uv[:, :uv.shape[1]] * self.sine_amp / 3
else:
noise = torch.randn_like(uv) * self.sine_amp / 3
return sine_merge, noise, uv
@@ -425,15 +411,16 @@ class HiFTGenerator(nn.Module):
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
# NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
self.m_source = this_SourceModuleHnNSF(
# NOTE in CosyVoice2, we use the original SineGen implementation
self.m_source = SourceModuleHnNSF(
sampling_rate=sampling_rate,
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
harmonic_num=nb_harmonics,
sine_amp=nsf_alpha,
add_noise_std=nsf_sigma,
voiced_threshod=nsf_voiced_threshold)
voiced_threshod=nsf_voiced_threshold,
sinegen_type='1' if self.sampling_rate == 22050 else '2',
causal=False)
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
self.conv_pre = weight_norm(
@@ -580,3 +567,180 @@ class HiFTGenerator(nn.Module):
s[:, :, :cache_source.shape[2]] = cache_source
generated_speech = self.decode(x=speech_feat, s=s)
return generated_speech, s
class CausalHiFTGenerator(HiFTGenerator):
"""
HiFTNet Generator: Neural Source Filter + ISTFTNet
https://arxiv.org/abs/2309.09493
"""
def __init__(
self,
in_channels: int = 80,
base_channels: int = 512,
nb_harmonics: int = 8,
sampling_rate: int = 22050,
nsf_alpha: float = 0.1,
nsf_sigma: float = 0.003,
nsf_voiced_threshold: float = 10,
upsample_rates: List[int] = [8, 8],
upsample_kernel_sizes: List[int] = [16, 16],
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
resblock_kernel_sizes: List[int] = [3, 7, 11],
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
source_resblock_kernel_sizes: List[int] = [7, 11],
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
lrelu_slope: float = 0.1,
audio_limit: float = 0.99,
conv_pre_look_right: int = 4,
f0_predictor: torch.nn.Module = None,
):
torch.nn.Module.__init__(self)
self.out_channels = 1
self.nb_harmonics = nb_harmonics
self.sampling_rate = sampling_rate
self.istft_params = istft_params
self.lrelu_slope = lrelu_slope
self.audio_limit = audio_limit
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.m_source = SourceModuleHnNSF(
sampling_rate=sampling_rate,
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
harmonic_num=nb_harmonics,
sine_amp=nsf_alpha,
add_noise_std=nsf_sigma,
voiced_threshod=nsf_voiced_threshold,
sinegen_type='1' if self.sampling_rate == 22050 else '2',
causal=True)
self.upsample_rates = upsample_rates
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
self.conv_pre = weight_norm(
CausalConv1d(in_channels, base_channels, conv_pre_look_right + 1, 1, causal_type='right')
)
# Up
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
CausalConv1dUpsample(
base_channels // (2**i),
base_channels // (2**(i + 1)),
k,
u,
)
)
)
# Down
self.source_downs = nn.ModuleList()
self.source_resblocks = nn.ModuleList()
downsample_rates = [1] + upsample_rates[::-1][:-1]
downsample_cum_rates = np.cumprod(downsample_rates)
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
if u == 1:
self.source_downs.append(
CausalConv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1, causal_type='left')
)
else:
self.source_downs.append(
CausalConv1dDownSample(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u)
)
self.source_resblocks.append(
ResBlock(base_channels // (2 ** (i + 1)), k, d, causal=True)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = base_channels // (2**(i + 1))
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(ResBlock(ch, k, d, causal=True))
self.conv_post = weight_norm(CausalConv1d(ch, istft_params["n_fft"] + 2, 7, 1, causal_type='left'))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
self.reflection_pad = nn.ReflectionPad1d((1, 0))
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
self.conv_pre_look_right = conv_pre_look_right
self.f0_predictor = f0_predictor
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0), finalize: bool = True) -> torch.Tensor:
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
if finalize is True:
x = self.conv_pre(x)
else:
x = self.conv_pre(x[:, :, :-self.conv_pre_look_right], x[:, :, -self.conv_pre_look_right:])
s_stft_real = s_stft_real[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
s_stft_imag = s_stft_imag[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, self.lrelu_slope)
x = self.ups[i](x)
if i == self.num_upsamples - 1:
x = self.reflection_pad(x)
# fusion
si = self.source_downs[i](s_stft)
si = self.source_resblocks[i](si)
x = x + si
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundant
x = self._istft(magnitude, phase)
if finalize is False:
x = x[:, :-int(np.prod(self.upsample_rates) * self.istft_params['hop_len'])]
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
return x
@torch.inference_mode()
def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> torch.Tensor:
# mel->f0. NOTE f0_predictor precision is crucial for causal inference; move self.f0_predictor to cpu if necessary
self.f0_predictor.to('cpu')
f0 = self.f0_predictor(speech_feat.cpu(), finalize=finalize).to(speech_feat)
# f0->source
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
s, _, _ = self.m_source(s)
s = s.transpose(1, 2)
if finalize is True:
generated_speech = self.decode(x=speech_feat, s=s, finalize=finalize)
else:
generated_speech = self.decode(x=speech_feat[:, :, :-self.f0_predictor.condnet[0].causal_padding], s=s, finalize=finalize)
return generated_speech, s
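# Toy sketch of the NSF source signal built by m_source above: voiced samples get a merged harmonic
# sine excitation, unvoiced samples get noise. This assumes a piecewise-constant f0 (the real
# SineGen accumulates instantaneous phase) and is not the repo's implementation.
import torch

def toy_source(f0: torch.Tensor, sr: int = 24000, harmonics: int = 8,
               sine_amp: float = 0.1, voiced_threshold: float = 10.0) -> torch.Tensor:
    # f0: (b, t) fundamental frequency per output sample, 0 for unvoiced
    t_idx = torch.arange(f0.shape[1], dtype=torch.float32) / sr
    uv = (f0 > voiced_threshold).float()
    phases = 2 * torch.pi * f0.unsqueeze(-1) * t_idx[None, :, None] * torch.arange(1, harmonics + 2, dtype=torch.float32)
    sines = sine_amp * torch.sin(phases).mean(dim=-1)          # merge harmonics into one excitation
    noise = torch.randn_like(sines) * sine_amp / 3
    return sines * uv + noise * (1 - uv)

print(toy_source(torch.full((1, 24000), 220.0)).shape)         # torch.Size([1, 24000])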
if __name__ == '__main__':
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from hyperpyyaml import load_hyperpyyaml
with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
configs = load_hyperpyyaml(f, overrides={'llm': None, 'flow': None})
model = configs['hift']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()
max_len, chunk_size, context_size = 300, 30, 8
mel = torch.rand(1, 80, max_len).to(device)
pred_gt, _ = model.inference(mel)
for i in range(0, max_len, chunk_size):
finalize = True if i + chunk_size + context_size >= max_len else False
pred_chunk, _ = model.inference(mel[:, :, : i + chunk_size + context_size], finalize=finalize)
pred_chunk = pred_chunk[:, i * 480:]
print((pred_gt[:, i * 480:i * 480 + pred_chunk.shape[1]] - pred_chunk).abs().max().item())

View File

@@ -17,6 +17,7 @@ import random
import time
import threading
from typing import Dict, Optional, Callable, List, Generator
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
@@ -56,8 +57,9 @@ class TransformerLM(torch.nn.Module):
)
# 2. build speech token language model related modules
self.sos_eos = 0
self.sos = 0
self.task_id = 1
self.eos_token = self.speech_token_size
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
self.llm = llm
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
@@ -85,10 +87,10 @@ class TransformerLM(torch.nn.Module):
encoder_out = self.text_encoder_affine_layer(encoder_out)
return encoder_out, encoder_out_lens
def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
for i in range(len(text_token))]
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
@@ -126,15 +128,15 @@ class TransformerLM(torch.nn.Module):
embedding = self.spk_embed_affine_layer(embedding)
embedding = embedding.unsqueeze(1)
# 3. eos and task_id
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
# 3. sos and task_id
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
# 4. encode speech_token
speech_token = self.speech_embedding(speech_token)
# 5. unpad and pad
lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
task_id_emb, speech_token, speech_token_len)
# 6. run lm forward
@@ -154,7 +156,7 @@ class TransformerLM(torch.nn.Module):
num_trials, max_trials = 0, 100
while True:
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
if (not ignore_eos) or (self.speech_token_size not in top_ids):
if (not ignore_eos) or (top_ids < self.speech_token_size):
break
num_trials += 1
if num_trials > max_trials:
@@ -193,13 +195,13 @@ class TransformerLM(torch.nn.Module):
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
# 3. concat llm_input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
# 4. cal min/max_length
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -215,11 +217,8 @@ class TransformerLM(torch.nn.Module):
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
device=lm_input.device)).to(torch.bool))
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
# force continue decode first token
if i == 0:
logp[:, self.speech_token_size] = -float('inf')
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
if top_ids == self.speech_token_size:
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
if top_ids == self.eos_token:
break
# in stream mode, yield token one by one
yield top_ids
@@ -276,9 +275,10 @@ class Qwen2LM(TransformerLM):
self.llm_output_size = llm_output_size
self.speech_token_size = speech_token_size
# 2. build speech token language model related modules
self.sos_eos = 0
self.sos = 0
self.task_id = 1
self.fill_token = 2
self.eos_token = speech_token_size
self.fill_token = speech_token_size + 2
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
self.llm = llm
@@ -301,18 +301,23 @@ class Qwen2LM(TransformerLM):
self.stop_token_ids = [speech_token_size + i for i in range(3)]
self.vllm_output_queue = {}
def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
lm_target, lm_input = [], []
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
# NOTE add instruct_token in CosyVoice3
if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
for i in range(len(text_token)):
# bistream sequence
if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
this_lm_target, this_lm_input = [], []
this_lm_target.append(IGNORE_ID)
this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
this_lm_target += [IGNORE_ID] * instruct_token_len[i]
this_lm_input.append(instruct_token_emb[i])
for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
@@ -320,22 +325,21 @@ class Qwen2LM(TransformerLM):
assert len(this_speech_token) == self.mix_ratio[1]
this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
this_lm_target += this_speech_token
this_lm_target.append(self.speech_token_size + 2)
this_lm_target.append(self.fill_token)
this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
else:
this_lm_target += [-1] * len(this_text_token)
this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
this_lm_target.append(self.speech_token_size)
this_lm_target.append(self.eos_token)
this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
this_lm_input.append(task_id_emb.squeeze(dim=0))
this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
# unistream sequence
else:
this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
this_lm_target = torch.tensor([IGNORE_ID] * (1 + instruct_token_len[i] + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
this_lm_input = torch.concat([sos_emb.squeeze(dim=0), instruct_token_emb[i], text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
lm_target.append(this_lm_target)
lm_input.append(this_lm_input)
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
@@ -363,11 +367,16 @@ class Qwen2LM(TransformerLM):
# 1. encode text_token
text_token_emb = self.llm.model.model.embed_tokens(text_token)
# 3. sos and task_id
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
# 2. encode speech_token
speech_token_emb = self.speech_embedding(speech_token)
# 3. prepare llm_input/target
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
speech_token, speech_token_emb, speech_token_len)
lm_target = lm_target.to(device)
# 4. run lm forward
@@ -392,6 +401,10 @@ class Qwen2LM(TransformerLM):
# 1. encode text_token
text_token_emb = self.llm.model.model.embed_tokens(text_token)
# 3. sos and task_id
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
# 2. encode speech_token
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
@@ -401,8 +414,8 @@ class Qwen2LM(TransformerLM):
speech_token_combined_emb = self.speech_embedding(speech_token_combined)
# 3. prepare llm_input/target
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
lm_target = lm_target.to(device)
# 4. run lm forward
@@ -445,13 +458,13 @@ class Qwen2LM(TransformerLM):
text = self.llm.model.model.embed_tokens(text)
# 3. concat llm_input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
# 4. cal min/max_length
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -500,11 +513,9 @@ class Qwen2LM(TransformerLM):
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
if top_ids == self.speech_token_size:
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
if top_ids in self.stop_token_ids:
break
if top_ids > self.speech_token_size:
continue
# in stream mode, yield token one by one
yield top_ids
out_tokens.append(top_ids)
@@ -526,20 +537,20 @@ class Qwen2LM(TransformerLM):
device = prompt_text.device
# 1. prepare input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
lm_input = torch.concat([sos_eos_emb], dim=1)
lm_input = torch.concat([sos_emb], dim=1)
# 2. iterate text
out_tokens = []
cache = None
# NOTE init text_cache with prompt_text, as it is basically impossible that prompt_speech_token/prompt_text < 15/5
text_cache = self.llm.model.model.embed_tokens(prompt_text)
next_fill_index = -1
next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
for this_text in text:
text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
# prompt_speech_token_emb not empty, try append to lm_input
@@ -554,12 +565,12 @@ class Qwen2LM(TransformerLM):
break
# no prompt_speech_token_emb remain, can decode some speech token
if prompt_speech_token_emb.size(1) == 0:
if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
logging.info('get fill token, need to append more text token')
if text_cache.size(1) >= self.mix_ratio[0]:
lm_input_text = text_cache[:, :self.mix_ratio[0]]
logging.info('append {} text token'.format(lm_input_text.size(1)))
if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
lm_input = lm_input_text
else:
lm_input = torch.concat([lm_input, lm_input_text], dim=1)
@@ -574,16 +585,16 @@ class Qwen2LM(TransformerLM):
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
if next_fill_index != -1 and len(out_tokens) == next_fill_index:
top_ids = self.speech_token_size + 2
top_ids = self.fill_token
next_fill_index += (self.mix_ratio[1] + 1)
else:
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
if top_ids == self.speech_token_size + 2:
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True)
if top_ids == self.fill_token:
next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
out_tokens.append(top_ids)
if top_ids >= self.speech_token_size:
if top_ids == self.speech_token_size + 2:
if top_ids == self.fill_token:
break
else:
raise ValueError('should not get token {}'.format(top_ids))
@@ -599,13 +610,136 @@ class Qwen2LM(TransformerLM):
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False)
out_tokens.append(top_ids)
if top_ids >= self.speech_token_size:
if top_ids == self.speech_token_size:
if top_ids == self.eos_token:
break
else:
raise ValueError('should not get token {}'.format(top_ids))
# in stream mode, yield token one by one
yield top_ids
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
class CosyVoice3LM(Qwen2LM):
def __init__(
self,
llm_input_size: int,
llm_output_size: int,
speech_token_size: int,
llm: torch.nn.Module,
sampling: Callable,
length_normalized_loss: bool = True,
lsm_weight: float = 0.0,
mix_ratio: List[int] = [5, 15],
):
torch.nn.Module.__init__(self)
self.llm_input_size = llm_input_size
self.llm_output_size = llm_output_size
self.speech_token_size = speech_token_size
# 2. build speech token language model related modules
self.sos = speech_token_size + 0
self.eos_token = speech_token_size + 1
self.task_id = speech_token_size + 2
self.fill_token = speech_token_size + 3
self.llm = llm
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
self.criterion_ce = LabelSmoothingLoss(
size=speech_token_size + 200,
padding_idx=IGNORE_ID,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
# 3. [Optional] build speech token related modules
self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
# 4. sampling method
self.sampling = sampling
self.mix_ratio = mix_ratio
# 5. vllm related
self.stop_token_ids = [speech_token_size + i for i in range(200)]
self.vllm_output_queue = {}
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
"""
Args:
text: (B, L, D)
text_lengths: (B,)
audio: (B, T, N) or (B, T)
audio_lengths: (B,)
"""
text_token = batch['text_token'].to(device)
text_token_len = batch['text_token_len'].to(device)
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
# NOTE should append instruct_token to sequence, not implemented yet
instruct_token = batch['instruct_token'].to(device)
instruct_token_len = batch['instruct_token_len'].to(device)
# 1. encode text_token
text_token_emb = self.llm.model.model.embed_tokens(text_token)
instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
# 3. sos and task_id
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
# 2. encode speech_token
speech_token_emb = self.speech_embedding(speech_token)
# 3. prepare llm_input/target
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
lm_target = lm_target.to(device)
# 4. run lm forward
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
logits = self.llm_decoder(lm_output)
loss = self.criterion_ce(logits, lm_target.to(device))
acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)
return {'loss': loss, 'acc': acc}
@torch.inference_mode()
def inference(
self,
text: torch.Tensor,
text_len: torch.Tensor,
prompt_text: torch.Tensor,
prompt_text_len: torch.Tensor,
prompt_speech_token: torch.Tensor,
prompt_speech_token_len: torch.Tensor,
embedding: torch.Tensor,
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
uuid: str = '',
) -> Generator[torch.Tensor, None, None]:
device = text.device
text = torch.concat([prompt_text, text], dim=1)
text_len += prompt_text_len
text = self.llm.model.model.embed_tokens(text)
# 3. concat llm_input
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
# 4. cal min/max_length
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
# 5. step by step decode
for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
yield token
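# CosyVoice3LM above places every control token after the speech codebook, so one embedding table
# of size speech_token_size + 200 covers both; a tiny sketch of the id layout (the
# speech_token_size value is illustrative, matching the 6561-entry codebook used elsewhere in this diff):
speech_token_size = 6561
sos = speech_token_size + 0               # start of sequence
eos_token = speech_token_size + 1         # end of speech
task_id = speech_token_size + 2           # marks where speech generation starts
fill_token = speech_token_size + 3        # "need more text" marker in bistream mode
stop_token_ids = [speech_token_size + i for i in range(200)]     # any non-codebook id stops decoding
print(sos, eos_token, task_id, fill_token, len(stop_token_ids))  # 6561 6562 6563 6564 200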

View File

@@ -238,7 +238,7 @@ def get_tokenizer(
)
class QwenTokenizer():
class CosyVoice2Tokenizer():
def __init__(self, token_path, skip_special_tokens=True):
super().__init__()
# NOTE: non-chat model, all these special tokens remain randomly initialized.
@@ -271,9 +271,57 @@ class QwenTokenizer():
return text
class CosyVoice3Tokenizer(CosyVoice2Tokenizer):
def __init__(self, token_path, skip_special_tokens=True):
# NOTE: non-chat model, all these special tokens remain randomly initialized.
special_tokens = {
'eos_token': '<|endoftext|>',
'pad_token': '<|endoftext|>',
'additional_special_tokens': [
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
'[breath]', '<strong>', '</strong>', '[noise]',
'[laughter]', '[cough]', '[clucking]', '[accent]',
'[quick_breath]',
"<laughter>", "</laughter>",
"[hissing]", "[sigh]", "[vocalized-noise]",
"[lipsmack]", "[mn]", "<|endofsystem|>",
"[AA]", "[AA0]", "[AA1]", "[AA2]", "[AE]", "[AE0]", "[AE1]", "[AE2]", "[AH]", "[AH0]", "[AH1]", "[AH2]",
"[AO]", "[AO0]", "[AO1]", "[AO2]", "[AW]", "[AW0]", "[AW1]", "[AW2]", "[AY]", "[AY0]", "[AY1]", "[AY2]",
"[B]", "[CH]", "[D]", "[DH]", "[EH]", "[EH0]", "[EH1]", "[EH2]", "[ER]", "[ER0]", "[ER1]", "[ER2]", "[EY]",
"[EY0]", "[EY1]", "[EY2]", "[F]", "[G]", "[HH]", "[IH]", "[IH0]", "[IH1]", "[IH2]", "[IY]", "[IY0]", "[IY1]",
"[IY2]", "[JH]", "[K]", "[L]", "[M]", "[N]", "[NG]", "[OW]", "[OW0]", "[OW1]", "[OW2]", "[OY]", "[OY0]",
"[OY1]", "[OY2]", "[P]", "[R]", "[S]", "[SH]", "[T]", "[TH]", "[UH]", "[UH0]", "[UH1]", "[UH2]", "[UW]",
"[UW0]", "[UW1]", "[UW2]", "[V]", "[W]", "[Y]", "[Z]", "[ZH]",
"[a]", "[ai]", "[an]", "[ang]", "[ao]", "[b]", "[c]", "[ch]", "[d]", "[e]", "[ei]", "[en]", "[eng]", "[f]",
"[g]", "[h]", "[i]", "[ian]", "[in]", "[ing]", "[iu]", "[ià]", "[iàn]", "[iàng]", "[iào]", "[iá]", "[ián]",
"[iáng]", "[iáo]", "[iè]", "[ié]", "[iòng]", "[ióng]", "[iù]", "[iú]", "[iā]", "[iān]", "[iāng]", "[iāo]",
"[iē]", "[iě]", "[iōng]", "[iū]", "[iǎ]", "[iǎn]", "[iǎng]", "[iǎo]", "[iǒng]", "[iǔ]", "[j]", "[k]", "[l]",
"[m]", "[n]", "[o]", "[ong]", "[ou]", "[p]", "[q]", "[r]", "[s]", "[sh]", "[t]", "[u]", "[uang]", "[ue]",
"[un]", "[uo]", "[uà]", "[uài]", "[uàn]", "[uàng]", "[uá]", "[uái]", "[uán]", "[uáng]", "[uè]", "[ué]", "[uì]",
"[uí]", "[uò]", "[uó]", "[uā]", "[uāi]", "[uān]", "[uāng]", "[uē]", "[uě]", "[uī]", "[uō]", "[uǎ]", "[uǎi]",
"[uǎn]", "[uǎng]", "[uǐ]", "[uǒ]", "[vè]", "[w]", "[x]", "[y]", "[z]", "[zh]", "[à]", "[ài]", "[àn]", "[àng]",
"[ào]", "[á]", "[ái]", "[án]", "[áng]", "[áo]", "[è]", "[èi]", "[èn]", "[èng]", "[èr]", "[é]", "[éi]", "[én]",
"[éng]", "[ér]", "[ì]", "[ìn]", "[ìng]", "[í]", "[ín]", "[íng]", "[ò]", "[òng]", "[òu]", "[ó]", "[óng]", "[óu]",
"[ù]", "[ùn]", "[ú]", "[ún]", "[ā]", "[āi]", "[ān]", "[āng]", "[āo]", "[ē]", "[ēi]", "[ēn]", "[ēng]", "[ě]",
"[ěi]", "[ěn]", "[ěng]", "[ěr]", "[ī]", "[īn]", "[īng]", "[ō]", "[ōng]", "[ōu]", "[ū]", "[ūn]", "[ǎ]", "[ǎi]",
"[ǎn]", "[ǎng]", "[ǎo]", "[ǐ]", "[ǐn]", "[ǐng]", "[ǒ]", "[ǒng]", "[ǒu]", "[ǔ]", "[ǔn]", "[ǘ]", "[ǚ]", "[ǜ]"
]
}
self.special_tokens = special_tokens
self.tokenizer = AutoTokenizer.from_pretrained(token_path)
self.tokenizer.add_special_tokens(special_tokens)
self.skip_special_tokens = skip_special_tokens
@lru_cache(maxsize=None)
def get_qwen_tokenizer(
token_path: str,
skip_special_tokens: bool
) -> QwenTokenizer:
return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
skip_special_tokens: bool,
version: str = 'cosyvoice2'
):
if version == 'cosyvoice2':
return CosyVoice2Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
elif version == 'cosyvoice3':
return CosyVoice3Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
else:
raise ValueError('unsupported tokenizer version: {}'.format(version))
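# Hedged usage sketch of the dispatcher defined above; the model directory and the encode()
# interface (inherited from CosyVoice2Tokenizer) are assumptions, adjust to the actual layout.
tokenizer = get_qwen_tokenizer(token_path='./pretrained_models/Fun-CosyVoice3-0.5B',
                               skip_special_tokens=True, version='cosyvoice3')
ids = tokenizer.encode('You are a helpful assistant. 请用四川话表达。<|endofprompt|>')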

View File

@@ -19,6 +19,7 @@ from typing import Tuple
import torch
from torch import nn
import torch.nn.functional as F
class ConvolutionModule(nn.Module):
@@ -143,3 +144,115 @@ class ConvolutionModule(nn.Module):
x.masked_fill_(~mask_pad, 0.0)
return x.transpose(1, 2), new_cache
# NOTE(Xiang Lyu) causal conv module used in convolution-based vocoder
class CausalConv1d(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
causal_type: str = 'left',
device=None,
dtype=None
) -> None:
super(CausalConv1d, self).__init__(in_channels, out_channels,
kernel_size, stride=1,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype)
assert stride == 1
self.causal_padding = int((kernel_size * dilation - dilation) / 2) * 2 + (kernel_size + 1) % 2
assert causal_type in ['left', 'right']
self.causal_type = causal_type
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor]:
input_timestep = x.shape[2]
if cache.size(2) == 0:
cache = torch.zeros(x.shape[0], x.shape[1], self.causal_padding).to(x)
assert cache.size(2) == self.causal_padding
if self.causal_type == 'left':
x = torch.concat([cache, x], dim=2)
else:
x = torch.concat([x, cache], dim=2)
x = super(CausalConv1d, self).forward(x)
assert x.shape[2] == input_timestep
return x
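# Usage sketch of CausalConv1d above (shapes hypothetical): output length equals input length, and
# streaming works by passing the previous chunk's last causal_padding frames as the cache.
conv = CausalConv1d(in_channels=4, out_channels=4, kernel_size=3, causal_type='left')
x = torch.rand(1, 4, 16)
y_full = conv(x)                                           # offline: zero cache is created internally
chunk1, chunk2 = x[:, :, :8], x[:, :, 8:]
y1 = conv(chunk1)                                          # first chunk, zero history
y2 = conv(chunk2, cache=chunk1[:, :, -conv.causal_padding:])
print(torch.allclose(torch.cat([y1, y2], dim=2), y_full))  # True: chunked == offline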
class CausalConv1dDownSample(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
super(CausalConv1dDownSample, self).__init__(in_channels, out_channels,
kernel_size, stride,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype)
assert stride != 1 and dilation == 1
assert kernel_size % stride == 0
self.causal_padding = stride - 1
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
if cache.size(2) == 0:
x = F.pad(x, (self.causal_padding, 0), value=0.0)
else:
assert cache.size(2) == self.causal_padding
x = torch.concat([cache, x], dim=2)
x = super(CausalConv1dDownSample, self).forward(x)
return x
class CausalConv1dUpsample(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
super(CausalConv1dUpsample, self).__init__(in_channels, out_channels,
kernel_size, 1,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype)
assert dilation == 1
self.causal_padding = kernel_size - 1
self.upsample = torch.nn.Upsample(scale_factor=stride, mode='nearest')
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
x = self.upsample(x)
input_timestep = x.shape[2]
if cache.size(2) == 0:
x = F.pad(x, (self.causal_padding, 0), value=0.0)
else:
assert cache.size(2) == self.causal_padding
x = torch.concat([cache, x], dim=2)
x = super(CausalConv1dUpsample, self).forward(x)
assert input_timestep == x.shape[2]
return x

View File

@@ -64,17 +64,18 @@ class Upsample1D(nn.Module):
class PreLookaheadLayer(nn.Module):
def __init__(self, channels: int, pre_lookahead_len: int = 1):
def __init__(self, in_channels: int, channels: int, pre_lookahead_len: int = 1):
super().__init__()
self.in_channels = in_channels
self.channels = channels
self.pre_lookahead_len = pre_lookahead_len
self.conv1 = nn.Conv1d(
channels, channels,
in_channels, channels,
kernel_size=pre_lookahead_len + 1,
stride=1, padding=0,
)
self.conv2 = nn.Conv1d(
channels, channels,
channels, in_channels,
kernel_size=3, stride=1, padding=0,
)
@@ -199,7 +200,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
# convolution module definition
convolution_layer_args = (output_size, cnn_module_kernel, activation,
cnn_module_norm, causal)
self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
self.pre_lookahead_layer = PreLookaheadLayer(in_channels=512, channels=512, pre_lookahead_len=3)
self.encoders = torch.nn.ModuleList([
ConformerEncoderLayer(
output_size,

View File

@@ -32,10 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
RelPositionMultiHeadedAttention)
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
from cosyvoice.llm.llm import TransformerLM, Qwen2LM
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
from cosyvoice.hifigan.generator import HiFTGenerator
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.llm.llm import TransformerLM, Qwen2LM, CosyVoice3LM
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
COSYVOICE_ACTIVATION_CLASSES = {
@@ -80,4 +80,6 @@ def get_model_type(configs):
return CosyVoiceModel
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
return CosyVoice2Model
if isinstance(configs['llm'], CosyVoice3LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
return CosyVoice3Model
raise TypeError('No valid model type found!')

View File

@@ -25,6 +25,33 @@ import torch
IGNORE_ID = -1
instruct_list = ["You are a helpful assistant. 请用广东话表达。<|endofprompt|>",
"You are a helpful assistant. 请用东北话表达。<|endofprompt|>",
"You are a helpful assistant. 请用甘肃话表达。<|endofprompt|>",
"You are a helpful assistant. 请用贵州话表达。<|endofprompt|>",
"You are a helpful assistant. 请用河南话表达。<|endofprompt|>",
"You are a helpful assistant. 请用湖北话表达。<|endofprompt|>",
"You are a helpful assistant. 请用湖南话表达。<|endofprompt|>",
"You are a helpful assistant. 请用江西话表达。<|endofprompt|>",
"You are a helpful assistant. 请用闽南话表达。<|endofprompt|>",
"You are a helpful assistant. 请用宁夏话表达。<|endofprompt|>",
"You are a helpful assistant. 请用山西话表达。<|endofprompt|>",
"You are a helpful assistant. 请用陕西话表达。<|endofprompt|>",
"You are a helpful assistant. 请用山东话表达。<|endofprompt|>",
"You are a helpful assistant. 请用上海话表达。<|endofprompt|>",
"You are a helpful assistant. 请用四川话表达。<|endofprompt|>",
"You are a helpful assistant. 请用天津话表达。<|endofprompt|>",
"You are a helpful assistant. 请用云南话表达。<|endofprompt|>",
"You are a helpful assistant. Please say a sentence as loudly as possible.<|endofprompt|>",
"You are a helpful assistant. Please say a sentence in a very soft voice.<|endofprompt|>",
"You are a helpful assistant. 请用尽可能慢地语速说一句话。<|endofprompt|>",
"You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>",
"You are a helpful assistant. 请非常开心地说一句话。<|endofprompt|>",
"You are a helpful assistant. 请非常伤心地说一句话。<|endofprompt|>",
"You are a helpful assistant. 请非常生气地说一句话。<|endofprompt|>",
"You are a helpful assistant. 我想体验一下小猪佩奇风格,可以吗?<|endofprompt|>",
"You are a helpful assistant. 你可以尝试用机器人的方式解答吗?<|endofprompt|>"]
def pad_list(xs: List[torch.Tensor], pad_value: int):
"""Perform padding for the list of tensors.
@@ -130,12 +157,12 @@ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
break
prob = torch.tensor(prob).to(weighted_scores)
indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
top_ids = indices[prob.multinomial(1, replacement=True)]
top_ids = indices[prob.multinomial(1, replacement=True)].item()
return top_ids
def random_sampling(weighted_scores, decoded_tokens, sampling):
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True).item()
return top_ids
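
The added .item() unwraps the sampled index from a one-element tensor into a plain Python int, presumably so that downstream token bookkeeping receives ints rather than tensors. A minimal sketch of the difference:

import torch

weighted_scores = torch.randn(100)
as_tensor = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
as_int = weighted_scores.softmax(dim=0).multinomial(1, replacement=True).item()
print(type(as_tensor), as_tensor.shape)   # <class 'torch.Tensor'> torch.Size([1])
print(type(as_int))                       # <class 'int'>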

View File

@@ -41,11 +41,11 @@ def read_json_lists(list_file):
return results
def load_wav(wav, target_sr):
def load_wav(wav, target_sr, min_sr=16000):
speech, sample_rate = torchaudio.load(wav, backend='soundfile')
speech = speech.mean(dim=0, keepdim=True)
if sample_rate != target_sr:
assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
assert sample_rate >= min_sr, 'wav sample rate {} must be no less than {}'.format(sample_rate, min_sr)
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
return speech
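
The new min_sr parameter decouples the lower bound on the input rate from the resampling target, so upsampling a 16 kHz prompt to a higher target no longer trips the old sample_rate > target_sr assert. A hedged usage sketch with placeholder file names:

from cosyvoice.utils.file_utils import load_wav

speech_24k = load_wav('prompt_16k.wav', target_sr=24000)   # ok: 16000 >= min_sr (default 16000)
# load_wav('prompt_8k.wav', target_sr=24000)                # would now fail: 8000 < min_sr
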
@@ -88,30 +88,18 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
logging.info("Successfully converted onnx to trt...")
# NOTE bistream inference is not supported, as only the speech token embedding/head is kept
def export_cosyvoice2_vllm(model, model_path, device):
if os.path.exists(model_path):
return
pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
vocab_size = model.speech_embedding.num_embeddings
feature_size = model.speech_embedding.embedding_dim
pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
dtype = torch.bfloat16
# lm_head
new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
with torch.no_grad():
new_lm_head.weight[:vocab_size] = model.llm_decoder.weight
new_lm_head.bias[:vocab_size] = model.llm_decoder.bias
new_lm_head.weight[vocab_size:] = 0
new_lm_head.bias[vocab_size:] = 0
model.llm.model.lm_head = new_lm_head
new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
use_bias = True if model.llm_decoder.bias is not None else False
model.llm.model.lm_head = model.llm_decoder
# embed_tokens
embed_tokens = model.llm.model.model.embed_tokens
with torch.no_grad():
new_codec_embed.weight[:vocab_size] = model.speech_embedding.weight
new_codec_embed.weight[vocab_size:] = 0
model.llm.model.set_input_embeddings(new_codec_embed)
model.llm.model.set_input_embeddings(model.speech_embedding)
model.llm.model.to(device)
model.llm.model.to(dtype)
tmp_vocab_size = model.llm.model.config.vocab_size
@@ -119,11 +107,12 @@ def export_cosyvoice2_vllm(model, model_path, device):
del model.llm.model.generation_config.eos_token_id
del model.llm.model.config.bos_token_id
del model.llm.model.config.eos_token_id
model.llm.model.config.vocab_size = pad_vocab_size
model.llm.model.config.vocab_size = model.speech_embedding.num_embeddings
model.llm.model.config.tie_word_embeddings = False
model.llm.model.config.use_bias = True
model.llm.model.config.use_bias = use_bias
model.llm.model.save_pretrained(model_path)
os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
if use_bias is True:
os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
model.llm.model.config.vocab_size = tmp_vocab_size
model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
model.llm.model.set_input_embeddings(embed_tokens)
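
A hedged sketch of calling the simplified export: the function now reuses the LLM's own speech embedding and decoder head instead of building zero-padded copies, so the exported vocab size matches speech_embedding.num_embeddings. Paths are placeholders and the import location is assumed from this file.

import torch
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import export_cosyvoice2_vllm   # assumed location of this helper

cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice2-0.5B')   # placeholder model dir
export_cosyvoice2_vllm(cosyvoice.model.llm,
                       'pretrained_models/CosyVoice2-0.5B/vllm',       # placeholder export dir
                       torch.device('cuda'))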

View File

@@ -53,7 +53,7 @@ def init_distributed(args):
def init_dataset_and_dataloader(args, configs, gan, dpo):
data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=True, partition=True)
cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=False, partition=False)
cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='dev', gan=gan, dpo=dpo, shuffle=False, partition=False)
# do not use persistent_workers=True, as the whisper tokenizer opens the tiktoken file each time the for loop starts
train_data_loader = DataLoader(train_dataset,
@@ -164,18 +164,18 @@ def init_optimizer_and_scheduler(args, configs, model, gan):
raise ValueError("unknown scheduler: " + configs['train_conf']['scheduler'])
if configs['train_conf']['optim_d'] == 'adam':
optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
elif configs['train_conf']['optim_d'] == 'adamw':
optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
else:
raise ValueError("unknown optimizer: " + configs['train_conf']['optim_d'])
if configs['train_conf']['scheduler_d'] == 'warmuplr':
scheduler_type = WarmupLR
scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf_d'])
elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
scheduler_type = NoamHoldAnnealing
scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf_d'])
elif configs['train_conf']['scheduler_d'] == 'constantlr':
scheduler_type = ConstantLR
scheduler_d = ConstantLR(optimizer_d)
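
For reference, a hedged sketch of the train_conf layout these branches now assume: discriminator settings live under _d-suffixed keys, mirroring optim_conf_d. The scheduler_conf_d key name follows that pattern and is an assumption here; values are placeholders.

train_conf = {
    'optim': 'adam',           'optim_conf': {'lr': 1e-4},
    'scheduler': 'warmuplr',   'scheduler_conf': {'warmup_steps': 2500},
    'optim_d': 'adam',         'optim_conf_d': {'lr': 2e-4},
    'scheduler_d': 'warmuplr', 'scheduler_conf_d': {'warmup_steps': 2500},
}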