Merge pull request #1140 from FunAudioLLM/dev/lyuxiang.lx

Dev/lyuxiang.lx
Xiang Lyu authored 2025-04-07 23:04:37 +08:00, committed by GitHub
33 changed files with 1901 additions and 298 deletions

View File

@@ -128,7 +128,7 @@ import torchaudio
 **CosyVoice2 Usage**
 ```python
-cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
+cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False)
 # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
 # zero_shot usage

View File

@@ -75,10 +75,11 @@ def main():
         print('Processing {}'.format(path))
         states = torch.load(path, map_location=torch.device('cpu'))
         for k in states.keys():
-            if k not in avg.keys():
-                avg[k] = states[k].clone()
-            else:
-                avg[k] += states[k]
+            if k not in ['step', 'epoch']:
+                if k not in avg.keys():
+                    avg[k] = states[k].clone()
+                else:
+                    avg[k] += states[k]
     # average
     for k in avg.keys():
         if avg[k] is not None:
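For orientation, a minimal standalone sketch of the averaging behaviour this hunk introduces: bookkeeping keys such as 'step' and 'epoch' are skipped before accumulation, and every remaining tensor is divided by the number of checkpoints at the end. The helper name and the final division step are assumptions based on the surrounding script, not a copy of it.

```python
# Hypothetical helper mirroring the hunk above: skip 'step'/'epoch', then average.
import torch

def average_checkpoints(paths):
    avg, num = {}, len(paths)
    for path in paths:
        states = torch.load(path, map_location='cpu')
        for k, v in states.items():
            if k in ['step', 'epoch']:        # metadata, not model weights
                continue
            avg[k] = v.clone() if k not in avg else avg[k] + v
    for k in avg:
        if avg[k] is not None:
            avg[k] = torch.true_divide(avg[k], num)
    return avg
```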

View File

@@ -24,6 +24,7 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
 from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.utils.file_utils import logging
 def get_args():
@@ -60,7 +61,8 @@ def main():
         model = CosyVoice(args.model_dir)
     except Exception:
         try:
-            model = CosyVoice2(args.model_dir)
+            # NOTE set use_flow_cache=True when export jit for cache inference
+            model = CosyVoice2(args.model_dir, use_flow_cache=True)
         except Exception:
             raise TypeError('no valid model_type!')
@@ -71,6 +73,7 @@ def main():
         script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
         script = get_optimized_script(llm_text_encoder.half())
         script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export llm_text_encoder')
         # 2. export llm llm
         llm_llm = model.model.llm.llm
@@ -78,13 +81,23 @@ def main():
         script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
         script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
         script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export llm_llm')
         # 3. export flow encoder
         flow_encoder = model.model.flow.encoder
         script = get_optimized_script(flow_encoder)
         script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
         script = get_optimized_script(flow_encoder.half())
         script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export flow_encoder')
+    else:
+        # 3. export flow encoder
+        flow_encoder = model.model.flow.encoder
+        script = get_optimized_script(flow_encoder, ['forward_chunk'])
+        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(flow_encoder.half(), ['forward_chunk'])
+        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export flow_encoder')
 if __name__ == '__main__':
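As a quick sanity check on the artifacts this script writes, here is a hedged sketch of loading one of the TorchScript archives back; the zip name matches what load_jit() consumes elsewhere in this PR, while model_dir is a placeholder.

```python
# Sketch: reload the exported flow encoder TorchScript archive.
import torch

model_dir = 'pretrained_models/CosyVoice2-0.5B'   # placeholder path
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
flow_encoder = torch.jit.load('{}/flow.encoder.fp32.zip'.format(model_dir), map_location=device)
flow_encoder.eval()   # scripted module; load_jit() swaps it in for model.flow.encoder
```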

View File

@@ -28,6 +28,7 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
 from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.utils.file_utils import logging
 def get_dummy_input(batch_size, seq_len, out_channels, device):
@@ -51,6 +52,7 @@ def get_args():
     return args
+@torch.no_grad()
 def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
@@ -60,56 +62,132 @@ def main():
         model = CosyVoice(args.model_dir)
     except Exception:
         try:
-            model = CosyVoice2(args.model_dir)
+            # NOTE set use_flow_cache=True when export jit for cache inference
+            model = CosyVoice2(args.model_dir, use_flow_cache=True)
         except Exception:
             raise TypeError('no valid model_type!')
-    # 1. export flow decoder estimator
-    estimator = model.model.flow.decoder.estimator
+    if not isinstance(model, CosyVoice2):
+        # 1. export flow decoder estimator
+        estimator = model.model.flow.decoder.estimator
+        estimator.eval()
         device = model.model.device
         batch_size, seq_len = 2, 256
         out_channels = model.model.flow.decoder.estimator.out_channels
         x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
         torch.onnx.export(
             estimator,
             (x, mask, mu, t, spks, cond),
             '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
             export_params=True,
             opset_version=18,
             do_constant_folding=True,
             input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
             output_names=['estimator_out'],
             dynamic_axes={
                 'x': {2: 'seq_len'},
                 'mask': {2: 'seq_len'},
                 'mu': {2: 'seq_len'},
                 'cond': {2: 'seq_len'},
                 'estimator_out': {2: 'seq_len'},
             }
         )
         # 2. test computation consistency
         option = onnxruntime.SessionOptions()
         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         option.intra_op_num_threads = 1
         providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
         estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
                                                       sess_options=option, providers=providers)
         for _ in tqdm(range(10)):
             x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
             output_pytorch = estimator(x, mask, mu, t, spks, cond)
             ort_inputs = {
                 'x': x.cpu().numpy(),
                 'mask': mask.cpu().numpy(),
                 'mu': mu.cpu().numpy(),
                 't': t.cpu().numpy(),
                 'spks': spks.cpu().numpy(),
                 'cond': cond.cpu().numpy()
             }
             output_onnx = estimator_onnx.run(None, ort_inputs)[0]
             torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+        logging.info('successfully export estimator')
+    else:
+        # 1. export flow decoder estimator
+        estimator = model.model.flow.decoder.estimator
+        estimator.forward = estimator.forward_chunk
+        estimator.eval()
+        device = model.model.device
+        batch_size, seq_len = 2, 256
+        out_channels = model.model.flow.decoder.estimator.out_channels
+        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+        cache = model.model.init_flow_cache()['decoder_cache']
+        cache.pop('offset')
+        cache = {k: v[0] for k, v in cache.items()}
+        torch.onnx.export(
+            estimator,
+            (x, mask, mu, t, spks, cond,
+             cache['down_blocks_conv_cache'],
+             cache['down_blocks_kv_cache'],
+             cache['mid_blocks_conv_cache'],
+             cache['mid_blocks_kv_cache'],
+             cache['up_blocks_conv_cache'],
+             cache['up_blocks_kv_cache'],
+             cache['final_blocks_conv_cache']),
+            '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+            export_params=True,
+            opset_version=18,
+            do_constant_folding=True,
+            input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache',
+                         'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'],
+            output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out',
+                          'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'],
+            dynamic_axes={
+                'x': {2: 'seq_len'},
+                'mask': {2: 'seq_len'},
+                'mu': {2: 'seq_len'},
+                'cond': {2: 'seq_len'},
+                'down_blocks_kv_cache': {3: 'cache_in_len'},
+                'mid_blocks_kv_cache': {3: 'cache_in_len'},
+                'up_blocks_kv_cache': {3: 'cache_in_len'},
+                'estimator_out': {2: 'seq_len'},
+                'down_blocks_kv_cache_out': {3: 'cache_out_len'},
+                'mid_blocks_kv_cache_out': {3: 'cache_out_len'},
+                'up_blocks_kv_cache_out': {3: 'cache_out_len'},
+            }
+        )
+        # 2. test computation consistency
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+                                                      sess_options=option, providers=providers)
+        for _ in tqdm(range(10)):
+            x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
+            cache = model.model.init_flow_cache()['decoder_cache']
+            cache.pop('offset')
+            cache = {k: v[0] for k, v in cache.items()}
+            output_pytorch = estimator(x, mask, mu, t, spks, cond, **{k: v.clone() for k, v in cache.items()})
+            ort_inputs = {
+                'x': x.cpu().numpy(),
+                'mask': mask.cpu().numpy(),
+                'mu': mu.cpu().numpy(),
+                't': t.cpu().numpy(),
+                'spks': spks.cpu().numpy(),
+                'cond': cond.cpu().numpy(),
+            }
+            output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}})
+            for i, j in zip(output_pytorch, output_onnx):
+                torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4)
+        logging.info('successfully export estimator')
 if __name__ == "__main__":

View File

@@ -1,10 +0,0 @@
-#!/bin/bash
-# Copyright 2024 Alibaba Inc. All Rights Reserved.
-# download tensorrt from https://developer.nvidia.com/tensorrt/download/10x, check your system and cuda for compatibability
-# for example for linux + cuda12.4, you can download https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
-TRT_DIR=<YOUR_TRT_DIR>
-MODEL_DIR=<COSYVOICE2_MODEL_DIR>
-export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64
-$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp32.mygpu.plan --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw --outputIOFormats=fp32:chw
-$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw

View File

@@ -23,7 +23,7 @@ from torch.utils.data import DataLoader
 import torchaudio
 from hyperpyyaml import load_hyperpyyaml
 from tqdm import tqdm
-from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.dataset.dataset import Dataset
@@ -33,6 +33,7 @@ def get_args():
     parser.add_argument('--prompt_data', required=True, help='prompt data file')
     parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
     parser.add_argument('--tts_text', required=True, help='tts input file')
+    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
     parser.add_argument('--llm_model', required=True, help='llm model file')
     parser.add_argument('--flow_model', required=True, help='flow model file')
     parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
@@ -59,16 +60,25 @@ def main():
     # Init cosyvoice models from configs
     use_cuda = args.gpu >= 0 and torch.cuda.is_available()
     device = torch.device('cuda' if use_cuda else 'cpu')
-    with open(args.config, 'r') as f:
-        configs = load_hyperpyyaml(f)
-    model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+    try:
+        with open(args.config, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
+        model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16=False)
+    except Exception:
+        try:
+            with open(args.config, 'r') as f:
+                configs = load_hyperpyyaml(f)
+            model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16=False)
+        except Exception:
+            raise TypeError('no valid model_type!')
     model.load(args.llm_model, args.flow_model, args.hifigan_model)
     test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
                            tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
     test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
+    sample_rate = configs['sample_rate']
     del configs
     os.makedirs(args.result_dir, exist_ok=True)
     fn = os.path.join(args.result_dir, 'wav.scp')
@@ -104,7 +114,7 @@ def main():
             tts_speeches = torch.concat(tts_speeches, dim=1)
             tts_key = '{}_{}'.format(utts[0], tts_index[0])
             tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
-            torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
+            torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
             f.write('{} {}\n'.format(tts_key, tts_fn))
             f.flush()
     f.close()

View File

@@ -46,6 +46,7 @@ def get_args():
     parser.add_argument('--config', required=True, help='config file')
     parser.add_argument('--train_data', required=True, help='train data file')
    parser.add_argument('--cv_data', required=True, help='cv data file')
+    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
     parser.add_argument('--checkpoint', help='checkpoint model')
     parser.add_argument('--model_dir', required=True, help='save model dir')
     parser.add_argument('--tensorboard_dir',
@@ -97,8 +98,12 @@ def main():
     override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
     if gan is True:
         override_dict.pop('hift')
-    with open(args.config, 'r') as f:
-        configs = load_hyperpyyaml(f, overrides=override_dict)
+    try:
+        with open(args.config, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
+    except Exception:
+        with open(args.config, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides=override_dict)
     if gan is True:
         configs['train_conf'] = configs['train_conf_gan']
     configs['train_conf'].update(vars(args))
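The new --qwen_pretrain_path flag is threaded into the config through a hyperpyyaml override. A hedged, self-contained sketch of what such an override does; the YAML snippet below is illustrative, not the actual cosyvoice2.yaml.

```python
# Sketch: a hyperpyyaml override filling a placeholder referenced via !ref.
from hyperpyyaml import load_hyperpyyaml

yaml_string = """
qwen_pretrain_path: PLACEHOLDER
pretrained_dir: !ref <qwen_pretrain_path>/CosyVoice-BlankEN
"""
configs = load_hyperpyyaml(yaml_string, overrides={'qwen_pretrain_path': '/data/pretrained'})
print(configs['pretrained_dir'])   # /data/pretrained/CosyVoice-BlankEN
```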

View File

@@ -32,7 +32,10 @@ class CosyVoice:
         self.fp16 = fp16
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
-        with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+        hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
+        if not os.path.exists(hyper_yaml_path):
+            raise ValueError('{} not found!'.format(hyper_yaml_path))
+        with open(hyper_yaml_path, 'r') as f:
             configs = load_hyperpyyaml(f)
         assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
@@ -126,13 +129,16 @@ class CosyVoice:
 class CosyVoice2(CosyVoice):
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
-        with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
+        if not os.path.exists(hyper_yaml_path):
+            raise ValueError('{} not found!'.format(hyper_yaml_path))
+        with open(hyper_yaml_path, 'r') as f:
             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
         assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
@@ -145,9 +151,9 @@ class CosyVoice2(CosyVoice):
         if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache)
         self.model.load('{}/llm.pt'.format(model_dir),
-                        '{}/flow.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
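Putting the new constructor arguments together, a hedged usage sketch for streaming synthesis with the cache-enabled flow model; it follows the README-style API (load_wav, inference_zero_shot, cosyvoice.sample_rate), and the prompt wav path and texts below are placeholders.

```python
# Sketch: chunked streaming synthesis with use_flow_cache=True (loads flow.cache.pt).
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B',
                       load_jit=False, load_trt=False, fp16=False,
                       use_flow_cache=True)
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)   # placeholder prompt
for i, out in enumerate(cosyvoice.inference_zero_shot(
        'Text to synthesize.', 'Transcript of the prompt audio.',
        prompt_speech_16k, stream=True)):
    # each yielded dict carries a 'tts_speech' chunk, as in CosyVoice2Model.tts below
    torchaudio.save('chunk_{}.wav'.format(i), out['tts_speech'], cosyvoice.sample_rate)
```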

View File

@@ -36,16 +36,12 @@ class CosyVoiceModel:
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        self.llm.fp16 = fp16
-        self.flow.fp16 = fp16
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
         self.token_min_hop_len = 2 * self.flow.input_frame_rate
         self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
-        # here we fix set flow.decoder.estimator.static_chunk_size = 0 for compatibability
-        self.flow.decoder.estimator.static_chunk_size = 0
         # mel fade in out
         self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
         self.mel_window = np.hamming(2 * self.mel_overlap_len)
@@ -87,19 +83,25 @@ class CosyVoiceModel:
     def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
         assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
         if not os.path.exists(flow_decoder_estimator_model):
-            convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
+            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
         if os.path.getsize(flow_decoder_estimator_model) == 0:
             raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
         del self.flow.decoder.estimator
         import tensorrt as trt
         with open(flow_decoder_estimator_model, 'rb') as f:
             self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
-        if self.flow.decoder.estimator_engine is None:
-            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
+        assert self.flow.decoder.estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
         self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
+    def get_trt_kwargs(self):
+        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
+        opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200)]
+        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
+        input_names = ["x", "mask", "mu", "cond"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
-        with self.llm_context:
+        with self.llm_context, torch.cuda.amp.autocast(self.fp16):
             if isinstance(text, Generator):
                 assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
                 for i in self.llm.inference_bistream(text=text,
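For context on get_trt_kwargs(): the min/opt/max shape triplets play the role of the --minShapes/--optShapes/--maxShapes flags in the shell script this PR deletes. A hedged sketch of how such kwargs could be turned into an engine with the TensorRT optimization-profile API; the repo's convert_onnx_to_trt is the authoritative implementation, this is only an illustration.

```python
# Sketch: building a TRT engine from the ONNX estimator using the shape kwargs.
import tensorrt as trt

def build_engine(onnx_path, plan_path, trt_kwargs, fp16):
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    with open(onnx_path, 'rb') as f:
        assert parser.parse(f.read()), 'failed to parse {}'.format(onnx_path)
    config = builder.create_builder_config()
    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)
    profile = builder.create_optimization_profile()
    for name, min_s, opt_s, max_s in zip(trt_kwargs['input_names'], trt_kwargs['min_shape'],
                                         trt_kwargs['opt_shape'], trt_kwargs['max_shape']):
        profile.set_shape(name, min_s, opt_s, max_s)   # dynamic seq_len / cache_in_len dims
    config.add_optimization_profile(profile)
    with open(plan_path, 'wb') as f:
        f.write(builder.build_serialized_network(network, config))
```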
@@ -121,15 +123,15 @@ class CosyVoiceModel:
             self.llm_end_dict[uuid] = True
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
-        tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
-                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                                  prompt_token=prompt_token.to(self.device),
-                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                  prompt_feat=prompt_feat.to(self.device),
-                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                                  embedding=embedding.to(self.device),
-                                                  flow_cache=self.flow_cache_dict[uuid])
-        self.flow_cache_dict[uuid] = flow_cache
+        with torch.cuda.amp.autocast(self.fp16):
+            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
+                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                                                      prompt_token=prompt_token.to(self.device),
+                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                                      prompt_feat=prompt_feat.to(self.device),
+                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                                                      embedding=embedding.to(self.device),
+                                                                      flow_cache=self.flow_cache_dict[uuid])
         # mel overlap fade in out
         if self.mel_overlap_dict[uuid].shape[2] != 0:
@@ -276,6 +278,7 @@ class CosyVoiceModel:
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
+            self.flow_cache_dict.pop(this_uuid)
         torch.cuda.empty_cache()
@@ -285,49 +288,88 @@ class CosyVoice2Model(CosyVoiceModel):
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
-                 fp16: bool):
+                 fp16: bool,
+                 use_flow_cache: bool):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        self.llm.fp16 = fp16
-        self.flow.fp16 = fp16
+        self.use_flow_cache = use_flow_cache
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
-        self.token_hop_len = 2 * self.flow.input_frame_rate
-        # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
-        self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
-        self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
+        # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
+        self.token_hop_len = 25
+        self.flow_decoder_required_cache_size = -1 if use_flow_cache is False else 1 * self.token_hop_len
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
         # speech fade in out
         self.speech_window = np.hamming(2 * self.source_cache_len)
         # rtf and decoding related
-        self.stream_scale_factor = 1
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
+        self.flow_cache_dict = {}
         self.hift_cache_dict = {}
+    def init_flow_cache(self):
+        encoder_cache = {'offset': 0,
+                         'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device),
+                         'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device),
+                         'upsample_offset': 0,
+                         'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device),
+                         'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)}
+        decoder_cache = {'offset': 0,
+                         'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device),
+                         'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device),
+                         'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device),
+                         'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, 0, 512, 2).to(self.device),
+                         'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
+                         'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device),
+                         'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
+        if self.fp16 is True:
+            for cache in [encoder_cache, decoder_cache]:
+                for k, v in cache.items():
+                    if isinstance(v, torch.Tensor):
+                        cache[k] = v.half()
+        cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}
+        return cache
+    def trim_flow_cache(self, cache):
+        if self.flow_decoder_required_cache_size > 0:
+            cache['decoder_cache']['down_blocks_kv_cache'] = cache['decoder_cache']['down_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
+            cache['decoder_cache']['mid_blocks_kv_cache'] = cache['decoder_cache']['mid_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
+            cache['decoder_cache']['up_blocks_kv_cache'] = cache['decoder_cache']['up_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
+        return cache
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
-    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
-        tts_mel, _ = self.flow.inference(token=token.to(self.device),
-                                         token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                         prompt_token=prompt_token.to(self.device),
-                                         prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                         prompt_feat=prompt_feat.to(self.device),
-                                         prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                         embedding=embedding.to(self.device),
-                                         finalize=finalize)
-        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
+    def get_trt_kwargs(self):
+        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (1, 4, 2, 0, 512, 2), (12, 4, 2, 0, 512, 2), (1, 4, 2, 0, 512, 2)]
+        opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200), (1, 4, 2, 100, 512, 2), (12, 4, 2, 100, 512, 2), (1, 4, 2, 100, 512, 2)]
+        max_shape = [(2, 80, 1500), (2, 1, 1500), (2, 80, 1500), (2, 80, 1500), (1, 4, 2, 200, 512, 2), (12, 4, 2, 200, 512, 2), (1, 4, 2, 200, 512, 2)]
+        input_names = ["x", "mask", "mu", "cond", 'down_blocks_kv_cache', 'mid_blocks_kv_cache', 'up_blocks_kv_cache']
+        assert self.use_flow_cache is True, "get_trt_kwargs is set for flow cache mode. If you want to use trt with use_flow_cache=False, please set higher max_shape"
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
+        with torch.cuda.amp.autocast(self.fp16):
+            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
+                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                                                      prompt_token=prompt_token.to(self.device),
+                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                                      prompt_feat=prompt_feat.to(self.device),
+                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                                                      embedding=embedding.to(self.device),
+                                                                      cache=self.flow_cache_dict[uuid],
+                                                                      finalize=finalize)
+        self.flow_cache_dict[uuid] = self.trim_flow_cache(self.flow_cache_dict[uuid])
         # append hift cache
         if self.hift_cache_dict[uuid] is not None:
             hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
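To make the bookkeeping in trim_flow_cache concrete, a small sketch of its effect on one decoder kv-cache entry; the shapes follow init_flow_cache above, and the 40-frame starting length is an arbitrary example.

```python
# Sketch: trimming a decoder kv cache to the last flow_decoder_required_cache_size frames.
import torch

required_cache_size = 25                              # 1 * token_hop_len in cache mode
kv_cache = torch.zeros(10, 1, 4, 2, 40, 512, 2)       # pretend 40 frames have been cached
kv_cache = kv_cache[:, :, :, :, -required_cache_size:]
print(kv_cache.shape)                                 # torch.Size([10, 1, 4, 2, 25, 512, 2])
```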
@@ -359,27 +401,34 @@ class CosyVoice2Model(CosyVoiceModel):
              prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
+        # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
+        if self.use_flow_cache is True:
+            flow_prompt_speech_token = flow_prompt_speech_token[:, -self.flow_decoder_required_cache_size:]
+            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size * 2:]
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
+            self.flow_cache_dict[this_uuid] = self.init_flow_cache()
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
         if stream is True:
-            token_offset = 0
             while True:
                 time.sleep(0.1)
-                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
-                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
+                if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                      prompt_token=flow_prompt_speech_token,
                                                      prompt_feat=prompt_speech_feat,
                                                      embedding=flow_embedding,
                                                      uuid=this_uuid,
-                                                     token_offset=token_offset,
                                                      finalize=False)
-                    token_offset += self.token_hop_len
+                    # NOTE in cache inference mode, we only use flow_prompt_speech_token/prompt_speech_feat in first chunk
+                    flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)
+                    prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
                     yield {'tts_speech': this_tts_speech.cpu()}
-                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
+                    with self.lock:
+                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:]
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len:
                     break
             p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
@@ -389,7 +438,6 @@ class CosyVoice2Model(CosyVoiceModel):
                                              prompt_feat=prompt_speech_feat,
                                              embedding=flow_embedding,
                                              uuid=this_uuid,
-                                             token_offset=token_offset,
                                              finalize=True)
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
@@ -401,11 +449,12 @@ class CosyVoice2Model(CosyVoiceModel):
                                              prompt_feat=prompt_speech_feat,
                                              embedding=flow_embedding,
                                              uuid=this_uuid,
-                                             token_offset=0,
                                              finalize=True,
                                              speed=speed)
             yield {'tts_speech': this_tts_speech.cpu()}
         with self.lock:
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
+            self.hift_cache_dict.pop(this_uuid)
+            self.flow_cache_dict.pop(this_uuid)
         torch.cuda.empty_cache()

View File

@@ -196,8 +196,8 @@ def compute_f0(data, sample_rate, hop_size, mode='train'):
         assert 'text_token' in sample
         waveform = sample['speech']
         _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
         if sum(_f0 != 0) < 5:  # this happens when the algorithm fails
             _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)  # if harvest fails, try dio
         f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
         f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
         sample['pitch_feat'] = f0

View File

@@ -11,14 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Tuple, Optional, Dict, Any
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import pack, rearrange, repeat
+from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate
 from cosyvoice.utils.common import mask_to_bias
 from cosyvoice.utils.mask import add_optional_chunk_mask
 from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
-from matcha.models.components.transformer import BasicTransformerBlock
+from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph
 class Transpose(torch.nn.Module):
@@ -27,34 +29,11 @@ class Transpose(torch.nn.Module):
         self.dim0 = dim0
         self.dim1 = dim1
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
         x = torch.transpose(x, self.dim0, self.dim1)
         return x
-class CausalBlock1D(Block1D):
-    def __init__(self, dim: int, dim_out: int):
-        super(CausalBlock1D, self).__init__(dim, dim_out)
-        self.block = torch.nn.Sequential(
-            CausalConv1d(dim, dim_out, 3),
-            Transpose(1, 2),
-            nn.LayerNorm(dim_out),
-            Transpose(1, 2),
-            nn.Mish(),
-        )
-    def forward(self, x: torch.Tensor, mask: torch.Tensor):
-        output = self.block(x * mask)
-        return output * mask
-class CausalResnetBlock1D(ResnetBlock1D):
-    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
-        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
-        self.block1 = CausalBlock1D(dim, dim_out)
-        self.block2 = CausalBlock1D(dim_out, dim_out)
 class CausalConv1d(torch.nn.Conv1d):
     def __init__(
         self,
@@ -76,12 +55,339 @@ class CausalConv1d(torch.nn.Conv1d):
             padding_mode=padding_mode,
             device=device, dtype=dtype)
         assert stride == 1
-        self.causal_padding = (kernel_size - 1, 0)
+        self.causal_padding = kernel_size - 1
-    def forward(self, x: torch.Tensor):
-        x = F.pad(x, self.causal_padding)
+    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
+        if cache.size(2) == 0:
+            x = F.pad(x, (self.causal_padding, 0), value=0.0)
+        else:
+            assert cache.size(2) == self.causal_padding
+            x = torch.concat([cache, x], dim=2)
+        cache = x[:, :, -self.causal_padding:]
         x = super(CausalConv1d, self).forward(x)
-        return x
+        return x, cache
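The cache argument added to CausalConv1d is what lets the decoder run chunk by chunk without changing its output. A hedged, self-contained sketch of the mechanism (a plain nn.Conv1d stands in for the class, the helper name is hypothetical): carrying the last kernel_size - 1 input frames between calls makes chunked causal convolution match the one-shot result.

```python
# Sketch: chunked causal conv with a carried cache equals full-sequence causal conv.
import torch
import torch.nn.functional as F

def causal_chunk(conv, x, cache, pad):
    # mirrors CausalConv1d.forward above
    if cache.size(2) == 0:
        x = F.pad(x, (pad, 0), value=0.0)
    else:
        x = torch.concat([cache, x], dim=2)
    cache = x[:, :, -pad:]
    return conv(x), cache

torch.manual_seed(0)
conv = torch.nn.Conv1d(4, 4, kernel_size=3)
pad = conv.kernel_size[0] - 1                       # causal_padding
x = torch.randn(1, 4, 10)

full = conv(F.pad(x, (pad, 0), value=0.0))          # one-shot causal conv

outs, cache = [], torch.zeros(1, 4, 0)
for chunk in x.split(5, dim=2):                     # two 5-frame chunks
    y, cache = causal_chunk(conv, chunk, cache, pad)
    outs.append(y)
assert torch.allclose(full, torch.cat(outs, dim=2), atol=1e-6)
```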
class CausalBlock1D(Block1D):
def __init__(self, dim: int, dim_out: int):
super(CausalBlock1D, self).__init__(dim, dim_out)
self.block = torch.nn.Sequential(
CausalConv1d(dim, dim_out, 3),
Transpose(1, 2),
nn.LayerNorm(dim_out),
Transpose(1, 2),
nn.Mish(),
)
def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
output, cache = self.block[0](x * mask, cache)
for i in range(1, len(self.block)):
output = self.block[i](output)
return output * mask, cache
class CausalResnetBlock1D(ResnetBlock1D):
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
self.block1 = CausalBlock1D(dim, dim_out)
self.block2 = CausalBlock1D(dim_out, dim_out)
def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor,
block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0)
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
h, block1_cache = self.block1(x, mask, block1_cache)
h += self.mlp(time_emb).unsqueeze(-1)
h, block2_cache = self.block2(h, mask, block2_cache)
output = h + self.res_conv(x * mask)
return output, block1_cache, block2_cache
class CausalAttnProcessor2_0(AttnProcessor2_0):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
super(CausalAttnProcessor2_0, self).__init__()
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
*args,
**kwargs,
) -> Tuple[torch.FloatTensor, torch.Tensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
# NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key_cache = attn.to_k(encoder_hidden_states)
value_cache = attn.to_v(encoder_hidden_states)
# NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
if cache.size(0) != 0:
key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
else:
key, value = key_cache, value_cache
cache = torch.stack([key_cache, value_cache], dim=3)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states, cache
@maybe_allow_in_graph
class CausalAttention(Attention):
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
cross_attention_norm: Optional[str] = None,
cross_attention_norm_num_groups: int = 32,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
spatial_norm_dim: Optional[int] = None,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor2_0"] = None,
out_dim: int = None,
):
super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax,
cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups,
spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection,
_from_deprecated_attn_block, processor, out_dim)
processor = CausalAttnProcessor2_0()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
**cross_attention_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
The forward method of the `Attention` class.
Args:
hidden_states (`torch.Tensor`):
The hidden states of the query.
encoder_hidden_states (`torch.Tensor`, *optional*):
The hidden states of the encoder.
attention_mask (`torch.Tensor`, *optional*):
The attention mask to use. If `None`, no mask is applied.
**cross_attention_kwargs:
Additional keyword arguments to pass along to the cross attention.
Returns:
`torch.Tensor`: The output of the attention layer.
"""
# The `Attention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cache=cache,
**cross_attention_kwargs,
)
@maybe_allow_in_graph
class CausalBasicTransformerBlock(BasicTransformerBlock):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
final_dropout: bool = False,
):
super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout,
cross_attention_dim, activation_fn, num_embeds_ada_norm,
attention_bias, only_cross_attention, double_self_attention,
upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
self.attn1 = CausalAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
) -> Tuple[torch.Tensor, torch.Tensor]:
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
else:
norm_hidden_states = self.norm1(hidden_states)
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
attn_output, cache = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
cache=cache,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states
# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \
{self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.")
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = ff_output + hidden_states
return hidden_states, cache
 class ConditionalDecoder(nn.Module):
@@ -89,7 +395,6 @@ class ConditionalDecoder(nn.Module):
         self,
         in_channels,
         out_channels,
-        causal=False,
         channels=(256, 256),
         dropout=0.05,
         attention_head_dim=64,
@@ -106,7 +411,7 @@ class ConditionalDecoder(nn.Module):
         channels = tuple(channels)
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.causal = causal
         self.time_embeddings = SinusoidalPosEmb(in_channels)
         time_embed_dim = channels[0] * 4
         self.time_mlp = TimestepEmbedding(
@@ -123,8 +428,7 @@ class ConditionalDecoder(nn.Module):
             input_channel = output_channel
             output_channel = channels[i]
             is_last = i == len(channels) - 1
-            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
-                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
                     BasicTransformerBlock(
@@ -138,16 +442,14 @@ class ConditionalDecoder(nn.Module):
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else
-                CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
         for _ in range(num_mid_blocks):
             input_channel = channels[-1]
             out_channels = channels[-1]
-            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
-                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
@@ -169,11 +471,7 @@ class ConditionalDecoder(nn.Module):
             input_channel = channels[i] * 2
             output_channel = channels[i + 1]
             is_last = i == len(channels) - 2
-            resnet = CausalResnetBlock1D(
-                dim=input_channel,
-                dim_out=output_channel,
-                time_emb_dim=time_embed_dim,
-            ) if self.causal else ResnetBlock1D(
+            resnet = ResnetBlock1D(
                 dim=input_channel,
                 dim_out=output_channel,
                 time_emb_dim=time_embed_dim,
@@ -193,10 +491,10 @@ class ConditionalDecoder(nn.Module):
             upsample = (
                 Upsample1D(output_channel, use_conv_transpose=True)
                 if not is_last
-                else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
-        self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
+        self.final_block = Block1D(channels[-1], channels[-1])
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()
@@ -214,7 +512,7 @@ class ConditionalDecoder(nn.Module):
             if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-    def forward(self, x, mask, mu, t, spks=None, cond=None):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         """Forward pass of the UNet1DConditional model.
         Args:
@@ -249,9 +547,8 @@ class ConditionalDecoder(nn.Module):
             mask_down = masks[-1]
             x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1) attn_mask = mask_to_bias(attn_mask, x.dtype)
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
for transformer_block in transformer_blocks: for transformer_block in transformer_blocks:
x = transformer_block( x = transformer_block(
hidden_states=x, hidden_states=x,
@@ -268,9 +565,8 @@ class ConditionalDecoder(nn.Module):
for resnet, transformer_blocks in self.mid_blocks: for resnet, transformer_blocks in self.mid_blocks:
x = resnet(x, mask_mid, t) x = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c").contiguous() x = rearrange(x, "b c t -> b t c").contiguous()
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1) attn_mask = mask_to_bias(attn_mask, x.dtype)
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
for transformer_block in transformer_blocks: for transformer_block in transformer_blocks:
x = transformer_block( x = transformer_block(
hidden_states=x, hidden_states=x,
@@ -285,9 +581,8 @@ class ConditionalDecoder(nn.Module):
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x = resnet(x, mask_up, t) x = resnet(x, mask_up, t)
x = rearrange(x, "b c t -> b t c").contiguous() x = rearrange(x, "b c t -> b t c").contiguous()
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1) attn_mask = mask_to_bias(attn_mask, x.dtype)
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
for transformer_block in transformer_blocks: for transformer_block in transformer_blocks:
x = transformer_block( x = transformer_block(
hidden_states=x, hidden_states=x,
@@ -299,3 +594,309 @@ class ConditionalDecoder(nn.Module):
x = self.final_block(x, mask_up) x = self.final_block(x, mask_up)
output = self.final_proj(x * mask_up) output = self.final_proj(x * mask_up)
return output * mask return output * mask
class CausalConditionalDecoder(ConditionalDecoder):
def __init__(
self,
in_channels,
out_channels,
channels=(256, 256),
dropout=0.05,
attention_head_dim=64,
n_blocks=1,
num_mid_blocks=2,
num_heads=4,
act_fn="snake",
static_chunk_size=50,
num_decoding_left_chunks=2,
):
"""
This decoder requires an input with the same shape as the target. So, if your text content
is shorter or longer than the outputs, please resample it before feeding it to the decoder.
"""
torch.nn.Module.__init__(self)
channels = tuple(channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.time_embeddings = SinusoidalPosEmb(in_channels)
time_embed_dim = channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=in_channels,
time_embed_dim=time_embed_dim,
act_fn="silu",
)
self.static_chunk_size = static_chunk_size
self.num_decoding_left_chunks = num_decoding_left_chunks
self.down_blocks = nn.ModuleList([])
self.mid_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
output_channel = in_channels
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
input_channel = output_channel
output_channel = channels[i]
is_last = i == len(channels) - 1
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
CausalBasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
downsample = (
Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
)
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
for _ in range(num_mid_blocks):
input_channel = channels[-1]
out_channels = channels[-1]
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
CausalBasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
channels = channels[::-1] + (channels[0],)
for i in range(len(channels) - 1):
input_channel = channels[i] * 2
output_channel = channels[i + 1]
is_last = i == len(channels) - 2
resnet = CausalResnetBlock1D(
dim=input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
)
transformer_blocks = nn.ModuleList(
[
CausalBasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
upsample = (
Upsample1D(output_channel, use_conv_transpose=True)
if not is_last
else CausalConv1d(output_channel, output_channel, 3)
)
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
self.final_block = CausalBlock1D(channels[-1], channels[-1])
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
"""Forward pass of the UNet1DConditional model.
Args:
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
cond (_type_, optional): placeholder for future use. Defaults to None.
streaming (bool, optional): use chunked attention masks when True. Defaults to False.
Returns:
torch.Tensor: predicted mel of shape (batch_size, out_channels, time), masked by `mask`.
"""
t = self.time_embeddings(t).to(t.dtype)
t = self.time_mlp(t)
x = pack([x, mu], "b * t")[0]
if spks is not None:
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
x = pack([x, spks], "b * t")[0]
if cond is not None:
x = pack([x, cond], "b * t")[0]
hiddens = []
masks = [mask]
for resnet, transformer_blocks, downsample in self.down_blocks:
mask_down = masks[-1]
x, _, _ = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c").contiguous()
if streaming is True:
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
else:
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x, _ = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
hiddens.append(x) # Save hidden states for skip connections
x, _ = downsample(x * mask_down)
masks.append(mask_down[:, :, ::2])
masks = masks[:-1]
mask_mid = masks[-1]
for resnet, transformer_blocks in self.mid_blocks:
x, _, _ = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c").contiguous()
if streaming is True:
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
else:
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x, _ = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
for resnet, transformer_blocks, upsample in self.up_blocks:
mask_up = masks.pop()
skip = hiddens.pop()
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x, _, _ = resnet(x, mask_up, t)
x = rearrange(x, "b c t -> b t c").contiguous()
if streaming is True:
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
else:
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x, _ = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
x, _ = upsample(x * mask_up)
x, _ = self.final_block(x, mask_up)
output = self.final_proj(x * mask_up)
return output * mask
def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward pass of the UNet1DConditional model.
Args:
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
cond (_type_, optional): placeholder for future use. Defaults to None.
*_conv_cache / *_kv_cache (torch.Tensor): cached convolution tails and per-block
key/value states carried over from the previous chunk.
Returns:
the predicted mel chunk of shape (batch_size, out_channels, time) together with
the updated conv/kv caches.
"""
t = self.time_embeddings(t).to(t.dtype)
t = self.time_mlp(t)
x = pack([x, mu], "b * t")[0]
if spks is not None:
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
x = pack([x, spks], "b * t")[0]
if cond is not None:
x = pack([x, cond], "b * t")[0]
hiddens = []
masks = [mask]
down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device)
up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
mask_down = masks[-1]
x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \
resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool()
attn_mask = mask_to_bias(attn_mask, x.dtype)
for i, transformer_block in enumerate(transformer_blocks):
x, down_blocks_kv_cache_new[index, i] = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
cache=down_blocks_kv_cache[index, i],
)
x = rearrange(x, "b t c -> b c t").contiguous()
hiddens.append(x) # Save hidden states for skip connections
x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:])
masks.append(mask_down[:, :, ::2])
masks = masks[:-1]
mask_mid = masks[-1]
for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \
resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool()
attn_mask = mask_to_bias(attn_mask, x.dtype)
for i, transformer_block in enumerate(transformer_blocks):
x, mid_blocks_kv_cache_new[index, i] = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
cache=mid_blocks_kv_cache[index, i]
)
x = rearrange(x, "b t c -> b c t").contiguous()
for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks):
mask_up = masks.pop()
skip = hiddens.pop()
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \
resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool()
attn_mask = mask_to_bias(attn_mask, x.dtype)
for i, transformer_block in enumerate(transformer_blocks):
x, up_blocks_kv_cache_new[index, i] = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
cache=up_blocks_kv_cache[index, i]
)
x = rearrange(x, "b t c -> b c t").contiguous()
x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])
x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
output = self.final_proj(x * mask_up)
return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \
up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache
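During cached inference the attention mask above degenerates to an all-ones rectangle over `current chunk × (cache + current chunk)`, which `mask_to_bias` turns into an additive bias for the attention logits. A small sketch of that construction, with `to_bias` standing in for `cosyvoice.utils.mask.mask_to_bias`, whose exact implementation is not part of this diff:

```python
# Sketch of the additive attention bias built in forward_chunk: every new frame
# may attend to all cached frames plus the current chunk.
import torch

def to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # True -> 0.0 (keep), False -> large negative (drop); added to attention logits
    return (~mask).to(dtype) * -1.0e10

t_chunk, t_cached = 4, 12
attn_mask = torch.ones(1, t_chunk, t_chunk + t_cached).bool()
bias = to_bias(attn_mask, torch.float32)   # all zeros: full attention over cache + chunk
```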

View File

@@ -91,6 +91,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
conds = conds.transpose(1, 2) conds = conds.transpose(1, 2)
mask = (~make_pad_mask(feat_len)).to(h) mask = (~make_pad_mask(feat_len)).to(h)
# NOTE this is unnecessary, feat and h already have the same shape
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1) feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
loss, _ = self.decoder.compute_loss( loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(), feat.transpose(1, 2).contiguous(),
@@ -111,16 +112,12 @@ class MaskedDiffWithXvec(torch.nn.Module):
prompt_feat_len, prompt_feat_len,
embedding, embedding,
flow_cache): flow_cache):
if self.fp16 is True:
prompt_feat = prompt_feat.half()
embedding = embedding.half()
assert token.shape[0] == 1 assert token.shape[0] == 1
# xvec projection # xvec projection
embedding = F.normalize(embedding, dim=1) embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding) embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text # concat speech token and prompt speech token
token_len1, token_len2 = prompt_token.shape[1], token.shape[1] token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
@@ -145,7 +142,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
cond=conds, cond=conds,
n_timesteps=10, n_timesteps=10,
prompt_len=mel_len1, prompt_len=mel_len1,
flow_cache=flow_cache cache=flow_cache
) )
feat = feat[:, :, mel_len1:] feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2 assert feat.shape[2] == mel_len2
@@ -190,6 +187,53 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
self.token_mel_ratio = token_mel_ratio self.token_mel_ratio = token_mel_ratio
self.pre_lookahead_len = pre_lookahead_len self.pre_lookahead_len = pre_lookahead_len
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)
# NOTE unified streaming/non-streaming training: randomly train with chunked attention (static_chunk_size > 0) or full attention (static_chunk_size = 0)
streaming = True if random.random() < 0.5 else False
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# mask and embed speech token
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
h = self.encoder_proj(h)
# get conditions
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.3 * j))
conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds,
streaming=streaming,
)
return {'loss': loss}
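The conditioning loop above can be read in isolation: with 50% probability a sample gets no prompt, otherwise a random prefix of up to 30% of its frames is copied into `conds`, so the decoder learns to continue from a mel prompt. A stand-alone illustration with dummy tensors:

```python
# Random prompt-prefix conditioning, mirroring the training forward above.
import random
import torch

feat = torch.randn(2, 100, 80)               # (batch, frames, mel bins), dummy values
feat_len = torch.tensor([100, 80])
conds = torch.zeros_like(feat)
for i, j in enumerate(feat_len):
    if random.random() < 0.5:
        continue                              # no prompt for this sample
    index = random.randint(0, int(0.3 * j))
    conds[i, :index] = feat[i, :index]        # expose a prefix as the prompt condition
conds = conds.transpose(1, 2)                 # (batch, mel bins, frames)
```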
@torch.inference_mode() @torch.inference_mode()
def inference(self, def inference(self,
token, token,
@@ -199,11 +243,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
prompt_feat, prompt_feat,
prompt_feat_len, prompt_feat_len,
embedding, embedding,
cache,
finalize): finalize):
if self.fp16 is True:
prompt_feat = prompt_feat.half()
embedding = embedding.half()
assert token.shape[0] == 1 assert token.shape[0] == 1
# xvec projection # xvec projection
embedding = F.normalize(embedding, dim=1) embedding = F.normalize(embedding, dim=1)
@@ -215,9 +256,17 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
token = self.input_embedding(torch.clamp(token, min=0)) * mask token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode # text encode
h, h_lengths = self.encoder(token, token_len) if finalize is True:
if finalize is False: h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache'])
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio] else:
token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache'])
cache['encoder_cache']['offset'] = encoder_cache[0]
cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1]
cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2]
cache['encoder_cache']['upsample_offset'] = encoder_cache[3]
cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4]
cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5]
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1] mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
h = self.encoder_proj(h) h = self.encoder_proj(h)
@@ -227,13 +276,14 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
conds = conds.transpose(1, 2) conds = conds.transpose(1, 2)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat, _ = self.decoder( feat, cache['decoder_cache'] = self.decoder(
mu=h.transpose(1, 2).contiguous(), mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1), mask=mask.unsqueeze(1),
spks=embedding, spks=embedding,
cond=conds, cond=conds,
n_timesteps=10 n_timesteps=10,
cache=cache['decoder_cache']
) )
feat = feat[:, :, mel_len1:] feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2 assert feat.shape[2] == mel_len2
return feat.float(), None return feat.float(), cache

View File

@@ -34,7 +34,7 @@ class ConditionalCFM(BASECFM):
self.lock = threading.Lock() self.lock = threading.Lock()
@torch.inference_mode() @torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)): def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
"""Forward diffusion """Forward diffusion
Args: Args:
@@ -54,19 +54,19 @@ class ConditionalCFM(BASECFM):
""" """
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
cache_size = flow_cache.shape[2] cache_size = cache.shape[2]
# fix prompt and overlap part mu and z # fix prompt and overlap part mu and z
if cache_size != 0: if cache_size != 0:
z[:, :, :cache_size] = flow_cache[:, :, :, 0] z[:, :, :cache_size] = cache[:, :, :, 0]
mu[:, :, :cache_size] = flow_cache[:, :, :, 1] mu[:, :, :cache_size] = cache[:, :, :, 1]
z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2) z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2) mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
flow_cache = torch.stack([z_cache, mu_cache], dim=-1) cache = torch.stack([z_cache, mu_cache], dim=-1)
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine': if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
def solve_euler(self, x, t_span, mu, mask, spks, cond): def solve_euler(self, x, t_span, mu, mask, spks, cond):
""" """
@@ -123,7 +123,7 @@ class ConditionalCFM(BASECFM):
def forward_estimator(self, x, mask, mu, t, spks, cond): def forward_estimator(self, x, mask, mu, t, spks, cond):
if isinstance(self.estimator, torch.nn.Module): if isinstance(self.estimator, torch.nn.Module):
return self.estimator.forward(x, mask, mu, t, spks, cond) return self.estimator(x, mask, mu, t, spks, cond)
else: else:
with self.lock: with self.lock:
self.estimator.set_input_shape('x', (2, 80, x.size(2))) self.estimator.set_input_shape('x', (2, 80, x.size(2)))
@@ -133,16 +133,16 @@ class ConditionalCFM(BASECFM):
self.estimator.set_input_shape('spks', (2, 80)) self.estimator.set_input_shape('spks', (2, 80))
self.estimator.set_input_shape('cond', (2, 80, x.size(2))) self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
# run trt engine # run trt engine
self.estimator.execute_v2([x.contiguous().data_ptr(), assert self.estimator.execute_v2([x.contiguous().data_ptr(),
mask.contiguous().data_ptr(), mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(), mu.contiguous().data_ptr(),
t.contiguous().data_ptr(), t.contiguous().data_ptr(),
spks.contiguous().data_ptr(), spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(), cond.contiguous().data_ptr(),
x.data_ptr()]) x.data_ptr()]) is True
return x return x
def compute_loss(self, x1, mask, mu, spks=None, cond=None): def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
"""Computes diffusion loss """Computes diffusion loss
Args: Args:
@@ -179,7 +179,7 @@ class ConditionalCFM(BASECFM):
spks = spks * cfg_mask.view(-1, 1) spks = spks * cfg_mask.view(-1, 1)
cond = cond * cfg_mask.view(-1, 1, 1) cond = cond * cfg_mask.view(-1, 1, 1)
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond) pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
return loss, y return loss, y
@@ -190,7 +190,7 @@ class CausalConditionalCFM(ConditionalCFM):
self.rand_noise = torch.randn([1, 80, 50 * 300]) self.rand_noise = torch.randn([1, 80, 50 * 300])
@torch.inference_mode() @torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}):
"""Forward diffusion """Forward diffusion
Args: Args:
@@ -209,9 +209,136 @@ class CausalConditionalCFM(ConditionalCFM):
shape: (batch_size, n_feats, mel_timesteps) shape: (batch_size, n_feats, mel_timesteps)
""" """
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature offset = cache.pop('offset')
z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
z = z[:, :, offset:]
offset += mu.size(2)
# fix prompt and overlap part mu and z # fix prompt and overlap part mu and z
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine': if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache)
cache['offset'] = offset
return mel, cache
def solve_euler(self, x, t_span, mu, mask, spks, cond, cache):
"""
Fixed-step Euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
t = t.unsqueeze(dim=0)
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
# estimator cache for each step
down_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
mid_blocks_kv_cache_new = torch.zeros(10, 12, 4, 2, x.size(2), 512, 2).to(x)
up_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
# Do not use concat here: it may change the memory format and cause the trt engine to produce wrong results!
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
for step in range(1, len(t_span)):
# Classifier-Free Guidance inference introduced in VoiceBox
x_in[:] = x
mask_in[:] = mask
mu_in[0] = mu
t_in[:] = t.unsqueeze(0)
spks_in[0] = spks
cond_in[0] = cond
cache_step = {k: v[step - 1] for k, v in cache.items()}
dphi_dt, cache_step = self.forward_estimator(
x_in, mask_in,
mu_in, t_in,
spks_in,
cond_in,
cache_step
)
cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
down_blocks_kv_cache_new[step - 1] = cache_step[1]
cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
mid_blocks_kv_cache_new[step - 1] = cache_step[3]
cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
up_blocks_kv_cache_new[step - 1] = cache_step[5]
cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
cache['down_blocks_kv_cache'] = torch.concat([cache['down_blocks_kv_cache'], down_blocks_kv_cache_new], dim=4)
cache['mid_blocks_kv_cache'] = torch.concat([cache['mid_blocks_kv_cache'], mid_blocks_kv_cache_new], dim=4)
cache['up_blocks_kv_cache'] = torch.concat([cache['up_blocks_kv_cache'], up_blocks_kv_cache_new], dim=4)
return sol[-1].float(), cache
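Stripped of the cache bookkeeping, `solve_euler` is a fixed-step Euler integrator with classifier-free guidance: the estimator is evaluated with and without conditioning, the two velocities are blended with `inference_cfg_rate`, and the state is advanced by `dt`. A minimal sketch with a dummy estimator; the real code batches the conditioned and unconditioned copies into one forward pass, and the `cfg_rate` value below is arbitrary:

```python
# Euler integration with classifier-free guidance (toy estimator).
import torch

def euler_cfg(z, t_span, estimator, cfg_rate=0.7):
    x, t = z, t_span[0]
    dt = t_span[1] - t_span[0]
    for step in range(1, len(t_span)):
        cond_v = estimator(x, t, conditioned=True)
        uncond_v = estimator(x, t, conditioned=False)
        v = (1.0 + cfg_rate) * cond_v - cfg_rate * uncond_v   # CFG-combined velocity
        x = x + dt * v
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
    return x

dummy = lambda x, t, conditioned: x * (0.1 if conditioned else 0.05)
t_span = 1 - torch.cos(torch.linspace(0, 1, 11) * 0.5 * torch.pi)  # cosine schedule
mel = euler_cfg(torch.randn(1, 80, 50), t_span, dummy)
```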
def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
if isinstance(self.estimator, torch.nn.Module):
x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
else:
with self.lock:
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
self.estimator.set_input_shape('t', (2,))
self.estimator.set_input_shape('spks', (2, 80))
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
self.estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
self.estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
self.estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
self.estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
self.estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
self.estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
self.estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
# run trt engine
down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
assert self.estimator.execute_v2([x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
cache['down_blocks_conv_cache'].contiguous().data_ptr(),
cache['down_blocks_kv_cache'].contiguous().data_ptr(),
cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
cache['up_blocks_conv_cache'].contiguous().data_ptr(),
cache['up_blocks_kv_cache'].contiguous().data_ptr(),
cache['final_blocks_conv_cache'].contiguous().data_ptr(),
x.data_ptr(),
cache['down_blocks_conv_cache'].data_ptr(),
down_blocks_kv_cache_out.data_ptr(),
cache['mid_blocks_conv_cache'].data_ptr(),
mid_blocks_kv_cache_out.data_ptr(),
cache['up_blocks_conv_cache'].data_ptr(),
up_blocks_kv_cache_out.data_ptr(),
cache['final_blocks_conv_cache'].data_ptr()]) is True
cache = (cache['down_blocks_conv_cache'],
down_blocks_kv_cache_out,
cache['mid_blocks_conv_cache'],
mid_blocks_kv_cache_out,
cache['up_blocks_conv_cache'],
up_blocks_kv_cache_out,
cache['final_blocks_conv_cache'])
return x, cache

View File

@@ -51,6 +51,7 @@ class InterpolateRegulator(nn.Module):
def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50): def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
# in inference mode, interpolate prompt token and token (head/mid/tail) separately, so we can get a clear separation point of mel # in inference mode, interpolate prompt token and token (head/mid/tail) separately, so we can get a clear separation point of mel
# NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
# x in (B, T, D) # x in (B, T, D)
if x2.shape[1] > 40: if x2.shape[1] > 40:
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')

View File

@@ -1,10 +1,16 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm import torch.nn.functional as F
try:
from torch.nn.utils.parametrizations import weight_norm, spectral_norm
except ImportError:
from torch.nn.utils import weight_norm, spectral_norm
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from einops import rearrange from einops import rearrange
from torchaudio.transforms import Spectrogram from torchaudio.transforms import Spectrogram
LRELU_SLOPE = 0.1
class MultipleDiscriminator(nn.Module): class MultipleDiscriminator(nn.Module):
def __init__( def __init__(
@@ -138,3 +144,87 @@ class DiscriminatorR(nn.Module):
x += h x += h
return x, fmap return x, fmap
class MultiResSpecDiscriminator(torch.nn.Module):
def __init__(self,
fft_sizes=[1024, 2048, 512],
hop_sizes=[120, 240, 50],
win_lengths=[600, 1200, 240],
window="hann_window"):
super(MultiResSpecDiscriminator, self).__init__()
self.discriminators = nn.ModuleList([
SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)])
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for _, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
def stft(x, fft_size, hop_size, win_length, window):
"""Perform STFT and convert to magnitude spectrogram.
Args:
x (Tensor): Input signal tensor (B, T).
fft_size (int): FFT size.
hop_size (int): Hop size.
win_length (int): Window length.
window (str): Window function type.
Returns:
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
"""
x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
return torch.abs(x_stft).transpose(2, 1)
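For reference, the helper above is equivalent to the following direct call, here with dummy audio and a hann window matching the default `SpecDiscriminator` settings:

```python
# Magnitude spectrogram via torch.stft, mirroring the stft helper above.
import torch

wav = torch.randn(2, 22050)                          # (batch, samples) of dummy audio
window = torch.hann_window(600)
spec = torch.stft(wav, 1024, 120, 600, window, return_complex=True)
mag = torch.abs(spec).transpose(2, 1)                # (batch, #frames, 1024 // 2 + 1)
```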
class SpecDiscriminator(nn.Module):
"""docstring for Discriminator."""
def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
super(SpecDiscriminator, self).__init__()
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
self.window = getattr(torch, window)(win_length)
self.discriminators = nn.ModuleList([
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
])
self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
def forward(self, y):
fmap = []
y = y.squeeze(1)
y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
y = y.unsqueeze(1)
for _, d in enumerate(self.discriminators):
y = d(y)
y = F.leaky_relu(y, LRELU_SLOPE)
fmap.append(y)
y = self.out(y)
fmap.append(y)
return torch.flatten(y, 1, -1), fmap

View File

@@ -13,7 +13,10 @@
# limitations under the License. # limitations under the License.
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm try:
from torch.nn.utils.parametrizations import weight_norm
except ImportError:
from torch.nn.utils import weight_norm
class ConvRNNF0Predictor(nn.Module): class ConvRNNF0Predictor(nn.Module):

View File

@@ -23,7 +23,10 @@ import torch.nn.functional as F
from torch.nn import Conv1d from torch.nn import Conv1d
from torch.nn import ConvTranspose1d from torch.nn import ConvTranspose1d
from torch.nn.utils import remove_weight_norm from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm try:
from torch.nn.utils.parametrizations import weight_norm
except ImportError:
from torch.nn.utils import weight_norm
from torch.distributions.uniform import Uniform from torch.distributions.uniform import Uniform
from cosyvoice.transformer.activation import Snake from cosyvoice.transformer.activation import Snake

View File

@@ -41,7 +41,7 @@ class HiFiGan(nn.Module):
loss_fm = feature_loss(fmap_rs, fmap_gs) loss_fm = feature_loss(fmap_rs, fmap_gs)
loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform) loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
if self.tpr_loss_weight != 0: if self.tpr_loss_weight != 0:
loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau) loss_tpr = tpr_loss(y_d_gs, y_d_rs, self.tpr_loss_tau)
else: else:
loss_tpr = torch.zeros(1).to(device) loss_tpr = torch.zeros(1).to(device)
loss_f0 = F.l1_loss(generated_f0, pitch_feat) loss_f0 = F.l1_loss(generated_f0, pitch_feat)
@@ -56,7 +56,7 @@ class HiFiGan(nn.Module):
with torch.no_grad(): with torch.no_grad():
generated_speech, generated_f0 = self.generator(batch, device) generated_speech, generated_f0 = self.generator(batch, device)
# 2. calculate discriminator outputs # 2. calculate discriminator outputs
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech) y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach())
# 3. calculate discriminator losses, tpr losses [Optional] # 3. calculate discriminator losses, tpr losses [Optional]
loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs) loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
if self.tpr_loss_weight != 0: if self.tpr_loss_weight != 0:

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random
from typing import Dict, Optional, Callable, List, Generator from typing import Dict, Optional, Callable, List, Generator
import torch import torch
from torch import nn from torch import nn
@@ -21,6 +22,7 @@ from cosyvoice.utils.common import IGNORE_ID
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy from cosyvoice.utils.common import th_accuracy
from cosyvoice.utils.file_utils import logging from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask
class TransformerLM(torch.nn.Module): class TransformerLM(torch.nn.Module):
@@ -169,9 +171,6 @@ class TransformerLM(torch.nn.Module):
max_token_text_ratio: float = 20, max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2, min_token_text_ratio: float = 2,
) -> Generator[torch.Tensor, None, None]: ) -> Generator[torch.Tensor, None, None]:
if self.fp16 is True:
embedding = embedding.half()
device = text.device device = text.device
text = torch.concat([prompt_text, text], dim=1) text = torch.concat([prompt_text, text], dim=1)
text_len += prompt_text_len text_len += prompt_text_len
@@ -229,6 +228,17 @@ class Qwen2Encoder(torch.nn.Module):
super().__init__() super().__init__()
self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path) self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T)
outs = self.model(
inputs_embeds=xs,
attention_mask=masks,
output_hidden_states=True,
return_dict=True,
)
return outs.hidden_states[-1], masks.unsqueeze(1)
def forward_one_step(self, xs, masks, cache=None): def forward_one_step(self, xs, masks, cache=None):
input_masks = masks[:, -1, :] input_masks = masks[:, -1, :]
outs = self.model( outs = self.model(
@@ -283,6 +293,82 @@ class Qwen2LM(TransformerLM):
self.sampling = sampling self.sampling = sampling
self.mix_ratio = mix_ratio self.mix_ratio = mix_ratio
def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
lm_target, lm_input = [], []
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
for i in range(len(text_token)):
# bistream sequence
if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
this_lm_target, this_lm_input = [], []
this_lm_target.append(IGNORE_ID)
this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
if len(this_text_token) == self.mix_ratio[0]:
assert len(this_speech_token) == self.mix_ratio[1]
this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
this_lm_target += this_speech_token
this_lm_target.append(self.speech_token_size + 2)
this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
else:
this_lm_target += [-1] * len(this_text_token)
this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
this_lm_target.append(self.speech_token_size)
this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
# unistream sequence
else:
this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
lm_target.append(this_lm_target)
lm_input.append(this_lm_input)
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
return lm_target, lm_input, lm_input_len
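A token-level illustration of the bistream layout `prepare_lm_input_target` builds: text is consumed in groups of `mix_ratio[0]` and speech in groups of `mix_ratio[1]`, with `<sos/eos>` in front and `<task_id>` marking where the remaining speech is flushed once the text runs out. The `mix_ratio` value below is assumed for illustration only; the real one comes from the model config.

```python
# lm_input token layout for one bistream sample, ignoring the target/IGNORE_ID bookkeeping.
text = [f"t{i}" for i in range(12)]       # 12 text tokens
speech = [f"s{i}" for i in range(40)]     # 40 speech tokens
mix_ratio = [5, 15]                       # assumed: 5 text tokens per 15 speech tokens

seq = ["<sos/eos>"]
for j in range((len(text) + mix_ratio[0]) // mix_ratio[0]):   # ceil((len + 1) / ratio)
    t_chunk = text[j * mix_ratio[0]:(j + 1) * mix_ratio[0]]
    s_chunk = speech[j * mix_ratio[1]:(j + 1) * mix_ratio[1]]
    if len(t_chunk) == mix_ratio[0]:
        seq += t_chunk + s_chunk                               # full text chunk, then its speech chunk
    else:
        seq += t_chunk + ["<task_id>"] + speech[j * mix_ratio[1]:]  # tail: flush the rest
print(seq)
```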
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
"""
Args:
text: (B, L, D)
text_lengths: (B,)
audio: (B, T, N) or (B, T)
audio_lengths: (B,)
"""
text_token = batch['text_token'].to(device)
text_token_len = batch['text_token_len'].to(device)
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
# 1. encode text_token
text_token_emb = self.llm.model.model.embed_tokens(text_token)
# 2. encode speech_token
speech_token_emb = self.speech_embedding(speech_token)
# 3. prepare llm_input/target
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
lm_target = lm_target.to(device)
# 4. run lm forward
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
logits = self.llm_decoder(lm_output)
loss = self.criterion_ce(logits, lm_target.to(device))
acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
return {'loss': loss, 'acc': acc}
@torch.inference_mode() @torch.inference_mode()
def inference( def inference(
self, self,
@@ -393,8 +479,8 @@ class Qwen2LM(TransformerLM):
while True: while True:
seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
y_pred, cache = self.llm.forward_one_step(lm_input, y_pred, cache = self.llm.forward_one_step(lm_input,
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
cache=cache) cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
if next_fill_index != -1 and len(out_tokens) == next_fill_index: if next_fill_index != -1 and len(out_tokens) == next_fill_index:
top_ids = self.speech_token_size + 2 top_ids = self.speech_token_size + 2

View File

@@ -287,8 +287,16 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
Returns: Returns:
torch.Tensor: Corresponding encoding torch.Tensor: Corresponding encoding
""" """
pos_emb = self.pe[ # How to subscript a Union type:
:, # https://github.com/pytorch/pytorch/issues/69434
self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size, if isinstance(offset, int):
] pos_emb = self.pe[
:,
self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
]
elif isinstance(offset, torch.Tensor):
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
]
return pos_emb return pos_emb
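The change above widens the symmetric slice of the relative positional table by `offset`, so a chunk that starts `offset` frames into the stream still sees relative positions covering its cached left context. A toy illustration of the indexing; the table values here are placeholders, not the real sinusoidal encoding:

```python
# Symmetric slice of a relative positional table, widened by the chunk offset.
import torch

max_len, d = 201, 4
pe = torch.arange(max_len).float().view(1, -1, 1).expand(1, max_len, d)  # toy table
center = pe.size(1) // 2

def rel_pos(size: int, offset: int) -> torch.Tensor:
    return pe[:, center - size - offset + 1: center + size + offset]

print(rel_pos(size=10, offset=0).shape)   # (1, 19, 4)  -> 2*size - 1 positions
print(rel_pos(size=10, offset=5).shape)   # (1, 29, 4)  -> widened by 2*offset
```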

View File

@@ -56,11 +56,16 @@ class Upsample1D(nn.Module):
# In this mode, first repeat interpolate, than conv with stride=1 # In this mode, first repeat interpolate, than conv with stride=1
self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0) self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor): def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest") outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0) if conv_cache.size(2) == 0:
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
else:
assert conv_cache.size(2) == self.stride * 2
outputs = torch.concat([conv_cache, outputs], dim=2)
conv_cache_new = outputs[:, :, -self.stride * 2:]
outputs = self.conv(outputs) outputs = self.conv(outputs)
return outputs, input_lengths * self.stride return outputs, input_lengths * self.stride, conv_cache_new
class PreLookaheadLayer(nn.Module): class PreLookaheadLayer(nn.Module):
@@ -78,22 +83,32 @@ class PreLookaheadLayer(nn.Module):
kernel_size=3, stride=1, padding=0, kernel_size=3, stride=1, padding=0,
) )
def forward(self, inputs: torch.Tensor) -> torch.Tensor: def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0), conv2_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
inputs: (batch_size, seq_len, channels) inputs: (batch_size, seq_len, channels)
""" """
outputs = inputs.transpose(1, 2).contiguous() outputs = inputs.transpose(1, 2).contiguous()
context = context.transpose(1, 2).contiguous()
# look ahead # look ahead
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0) if context.size(2) == 0:
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
else:
assert context.size(2) == self.pre_lookahead_len
outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
outputs = F.leaky_relu(self.conv1(outputs)) outputs = F.leaky_relu(self.conv1(outputs))
# outputs # outputs
outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0) if conv2_cache.size(2) == 0:
outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
else:
assert conv2_cache.size(2) == self.conv2.kernel_size[0] - 1
outputs = torch.concat([conv2_cache, outputs], dim=2)
conv2_cache_new = outputs[:, :, -(self.conv2.kernel_size[0] - 1):]
outputs = self.conv2(outputs) outputs = self.conv2(outputs)
outputs = outputs.transpose(1, 2).contiguous() outputs = outputs.transpose(1, 2).contiguous()
# residual connection # residual connection
outputs = outputs + inputs outputs = outputs + inputs
return outputs return outputs, conv2_cache_new
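The conv caches introduced above follow the standard streaming trick for causal convolutions: the first chunk is left-padded with zeros, later chunks prepend the last `kernel_size - 1` frames of the previous (padded) input, and the chunked output then matches the offline result exactly. A self-contained sketch, assuming a plain `nn.Conv1d` rather than the repository's layers:

```python
# Streaming a causal Conv1d with a (kernel_size - 1)-frame cache.
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv1d(4, 4, kernel_size=3, padding=0)

def conv_chunk(x, cache):
    # x: (B, C, T_chunk), cache: (B, C, kernel_size - 1) or empty on the first chunk
    if cache.size(2) == 0:
        x = F.pad(x, (conv.kernel_size[0] - 1, 0))
    else:
        x = torch.cat([cache, x], dim=2)
    return conv(x), x[:, :, -(conv.kernel_size[0] - 1):]

cache = torch.zeros(1, 4, 0)
chunks = torch.randn(1, 4, 12).split(4, dim=2)
outs = []
for c in chunks:
    y, cache = conv_chunk(c, cache)
    outs.append(y)
full = conv(F.pad(torch.cat(chunks, dim=2), (2, 0)))        # offline reference
assert torch.allclose(torch.cat(outs, dim=2), full, atol=1e-6)
```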
class UpsampleConformerEncoder(torch.nn.Module): class UpsampleConformerEncoder(torch.nn.Module):
@@ -240,6 +255,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
xs_lens: torch.Tensor, xs_lens: torch.Tensor,
decoding_chunk_size: int = 0, decoding_chunk_size: int = 0,
num_decoding_left_chunks: int = -1, num_decoding_left_chunks: int = -1,
streaming: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed positions in tensor. """Embed positions in tensor.
@@ -270,30 +286,20 @@ class UpsampleConformerEncoder(torch.nn.Module):
xs = self.global_cmvn(xs) xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks) xs, pos_emb, masks = self.embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate) mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs, masks, chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
self.use_dynamic_chunk,
self.use_dynamic_left_chunk,
decoding_chunk_size,
self.static_chunk_size,
num_decoding_left_chunks)
# lookahead + conformer encoder # lookahead + conformer encoder
xs = self.pre_lookahead_layer(xs) xs, _ = self.pre_lookahead_layer(xs)
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
# upsample + conformer encoder # upsample + conformer encoder
xs = xs.transpose(1, 2).contiguous() xs = xs.transpose(1, 2).contiguous()
xs, xs_lens = self.up_layer(xs, xs_lens) xs, xs_lens, _ = self.up_layer(xs, xs_lens)
xs = xs.transpose(1, 2).contiguous() xs = xs.transpose(1, 2).contiguous()
T = xs.size(1) T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
xs, pos_emb, masks = self.up_embed(xs, masks) xs, pos_emb, masks = self.up_embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate) mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs, masks, chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
self.use_dynamic_chunk,
self.use_dynamic_left_chunk,
decoding_chunk_size,
self.static_chunk_size * self.up_layer.stride,
num_decoding_left_chunks)
xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad) xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before: if self.normalize_before:
@@ -316,3 +322,100 @@ class UpsampleConformerEncoder(torch.nn.Module):
for layer in self.up_encoders: for layer in self.up_encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
return xs return xs
@torch.jit.export
def forward_chunk(
self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
offset: int = 0,
context: torch.Tensor = torch.zeros(0, 0, 0),
pre_lookahead_layer_conv2_cache: torch.Tensor = torch.zeros(0, 0, 0),
encoders_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0),
upsample_offset: int = 0,
upsample_conv_cache: torch.Tensor = torch.zeros(0, 0, 0),
upsample_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0)
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[int, torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs: padded input tensor (B, T, D)
xs_lens: input length (B)
offset: number of frames already consumed by previous chunks, used to index
the positional encoding.
context: lookahead frames taken from the start of the next chunk, consumed
by the pre-lookahead layer.
pre_lookahead_layer_conv2_cache / upsample_conv_cache: cached tail frames of
the causal convolutions.
encoders_kv_cache / upsample_kv_cache: per-layer key/value caches of the
conformer blocks.
upsample_offset: like `offset`, but counted after upsampling.
Returns:
encoder output tensor xs (B, T', D), its padding mask (B, 1, T'), and the
updated cache tuple (offset, pre_lookahead_layer_conv2_cache,
encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache)
NOTE(xcsong):
We pass the `__call__` method of the modules instead of `forward` to the
checkpointing API because `__call__` attaches all the hooks of the module.
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
"""
assert xs.size(0) == 1
# tmp_masks is just for interface compatibility
tmp_masks = torch.ones(1,
xs.size(1),
device=xs.device,
dtype=torch.bool)
tmp_masks = tmp_masks.unsqueeze(1)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
# NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
offset += xs.size(1)
tmp_masks = torch.ones(1,
context.size(1),
device=context.device,
dtype=torch.bool)
tmp_masks = tmp_masks.unsqueeze(1)
if context.size(1) != 0:
context, _, _ = self.embed(context, tmp_masks, offset)
# lookahead + conformer encoder
xs, pre_lookahead_layer_conv2_cache = self.pre_lookahead_layer(xs, context, pre_lookahead_layer_conv2_cache)
# NOTE in cache mode we do not need to call add_optional_chunk_mask
chunk_masks = torch.ones((1, xs.size(1), offset), dtype=torch.bool, device=xs.device)
mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
encoders_kv_cache_list = []
for index, layer in enumerate(self.encoders):
xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
encoders_kv_cache_list.append(encoders_kv_cache_new)
encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)
# upsample
xs = xs.transpose(1, 2).contiguous()
xs, xs_lens, upsample_conv_cache = self.up_layer(xs, xs_lens, upsample_conv_cache)
xs = xs.transpose(1, 2).contiguous()
# tmp_masks is just for interface compatibility
tmp_masks = torch.ones(1,
xs.size(1),
device=xs.device,
dtype=torch.bool)
tmp_masks = tmp_masks.unsqueeze(1)
xs, pos_emb, masks = self.up_embed(xs, tmp_masks, upsample_offset)
upsample_offset += xs.size(1)
# conformer encoder
chunk_masks = torch.ones((1, xs.size(1), upsample_offset), dtype=torch.bool, device=xs.device)
mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
upsample_kv_cache_list = []
for index, layer in enumerate(self.up_encoders):
xs, chunk_masks, upsample_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, upsample_kv_cache[index])
upsample_kv_cache_list.append(upsample_kv_cache_new)
upsample_kv_cache = torch.stack(upsample_kv_cache_list, dim=0)
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
# return the masks before encoder layers, and the masks will be used
# for cross attention with decoder later
return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache)
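The new `forward_chunk` threads all streaming state through explicit cache tensors, so a caller loops over feature chunks and feeds each returned cache back into the next call. Below is a minimal driver sketch, assuming `encoder` is an instantiated `UpsampleConformerEncoder` and `feat_chunks` is a list of `(1, T_chunk, input_size)` tensors (both assumptions, not part of this diff); the zero-sized tensors mirror the defaults in the signature above, and the 3-frame lookahead slice is only illustrative.

```python
import torch

offset, upsample_offset = 0, 0
conv2_cache = torch.zeros(0, 0, 0)              # pre_lookahead_layer_conv2_cache
encoders_kv_cache = torch.zeros(0, 0, 0, 0, 0)
upsample_conv_cache = torch.zeros(0, 0, 0)
upsample_kv_cache = torch.zeros(0, 0, 0, 0, 0)

outputs = []
for i, chunk in enumerate(feat_chunks):         # chunk: (1, T_chunk, input_size)
    # hand the head of the next chunk over as lookahead context when available
    # (3 frames here, matching pre_lookahead_len in the yaml; an assumption)
    context = feat_chunks[i + 1][:, :3] if i + 1 < len(feat_chunks) else torch.zeros(1, 0, chunk.size(2))
    xs_lens = torch.tensor([chunk.size(1)], dtype=torch.int32)
    xs, masks, (offset, conv2_cache, encoders_kv_cache,
                upsample_offset, upsample_conv_cache, upsample_kv_cache) = encoder.forward_chunk(
        chunk, xs_lens, offset, context,
        conv2_cache, encoders_kv_cache,
        upsample_offset, upsample_conv_cache, upsample_kv_cache)
    outputs.append(xs)

encoded = torch.cat(outputs, dim=1)             # (1, total upsampled length, D)
```

In the repo this loop is driven internally during streaming flow inference; the sketch only illustrates the cache contract.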

View File

@@ -47,13 +47,8 @@ def load_wav(wav, target_sr):
return speech return speech
def convert_onnx_to_trt(trt_model, onnx_model, fp16): def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
import tensorrt as trt import tensorrt as trt
_min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)]
_opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)]
_max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)]
input_names = ["x", "mask", "mu", "t", "spks", "cond"]
logging.info("Converting onnx to trt...") logging.info("Converting onnx to trt...")
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
logger = trt.Logger(trt.Logger.INFO) logger = trt.Logger(trt.Logger.INFO)
@@ -72,8 +67,8 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
print(parser.get_error(error)) print(parser.get_error(error))
raise ValueError('failed to parse {}'.format(onnx_model)) raise ValueError('failed to parse {}'.format(onnx_model))
# set input shapes # set input shapes
for i in range(len(input_names)): for i in range(len(trt_kwargs['input_names'])):
profile.set_shape(input_names[i], _min_shape[i], _opt_shape[i], _max_shape[i]) profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
# set input and output data type # set input and output data type
for i in range(network.num_inputs): for i in range(network.num_inputs):
@@ -87,3 +82,4 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
# save trt engine # save trt engine
with open(trt_model, "wb") as f: with open(trt_model, "wb") as f:
f.write(engine_bytes) f.write(engine_bytes)
logging.info("Succesfully convert onnx to trt...")

View File

@@ -15,7 +15,6 @@
# limitations under the License. # limitations under the License.
import torch import torch
from cosyvoice.utils.file_utils import logging
''' '''
def subsequent_mask( def subsequent_mask(
size: int, size: int,
@@ -87,7 +86,7 @@ def subsequent_mask(
return mask return mask
def subsequent_chunk_mask_deprecated( def subsequent_chunk_mask(
size: int, size: int,
chunk_size: int, chunk_size: int,
num_left_chunks: int = -1, num_left_chunks: int = -1,
@@ -125,41 +124,6 @@ def subsequent_chunk_mask_deprecated(
return ret return ret
def subsequent_chunk_mask(
size: int,
chunk_size: int,
num_left_chunks: int = -1,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size) with chunk size,
this is for streaming encoder
Args:
size (int): size of mask
chunk_size (int): size of chunk
num_left_chunks (int): number of left chunks
<0: use full chunk
>=0: use num_left_chunks
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_chunk_mask(4, 2)
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1]]
"""
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
# actually this is not needed after we have inference cache implemented, will remove it later
pos_idx = torch.arange(size, device=device)
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
return ret
def add_optional_chunk_mask(xs: torch.Tensor, def add_optional_chunk_mask(xs: torch.Tensor,
masks: torch.Tensor, masks: torch.Tensor,
use_dynamic_chunk: bool, use_dynamic_chunk: bool,
@@ -233,8 +197,8 @@ def add_optional_chunk_mask(xs: torch.Tensor,
chunk_masks = masks chunk_masks = masks
assert chunk_masks.dtype == torch.bool assert chunk_masks.dtype == torch.bool
if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0: if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
logging.warning('get chunk_masks all false at some timestep, force set to true, make sure they are masked in future computation!') print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in future computation!')
chunk_masks[chunk_masks.sum(dim=-1)==0] = True chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
return chunk_masks return chunk_masks
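With the onnx-friendly variant removed, the surviving `subsequent_chunk_mask` is again the `num_left_chunks`-aware implementation, and it still yields the block-causal pattern from the deleted docstring. A quick sanity check, assuming the module path is `cosyvoice.utils.mask`:

```python
from cosyvoice.utils.mask import subsequent_chunk_mask  # module path assumed from this diff

# 4 positions, chunk_size=2, full left context (num_left_chunks=-1)
print(subsequent_chunk_mask(4, 2).int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]])
```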

View File

@@ -286,11 +286,15 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
# optimizer.step(). # optimizer.step().
if torch.isfinite(grad_norm): if torch.isfinite(grad_norm):
scaler.step(optimizer) scaler.step(optimizer)
else:
logging.warning('got an infinite grad_norm, check your code/data if this appears frequently')
scaler.update() scaler.update()
else: else:
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip']) grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
if torch.isfinite(grad_norm): if torch.isfinite(grad_norm):
optimizer.step() optimizer.step()
else:
logging.warning('got an infinite grad_norm, check your code/data if this appears frequently')
optimizer.zero_grad() optimizer.zero_grad()
scheduler.step() scheduler.step()
info_dict["lr"] = optimizer.param_groups[0]['lr'] info_dict["lr"] = optimizer.param_groups[0]['lr']
@@ -336,7 +340,7 @@ def log_per_save(writer, info_dict):
rank = int(os.environ.get('RANK', 0)) rank = int(os.environ.get('RANK', 0))
logging.info( logging.info(
'Epoch {} Step {} CV info lr {} {} rank {}'.format( 'Epoch {} Step {} CV info lr {} {} rank {}'.format(
epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()]))) epoch, step + 1, lr, rank, ' '.join(['{} {}'.format(k, v) for k, v in loss_dict.items()])))
if writer is not None: if writer is not None:
for k in ['epoch', 'lr']: for k in ['epoch', 'lr']:

View File

@@ -147,7 +147,7 @@ hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
generator: !ref <hift> generator: !ref <hift>
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
mel_spec_transform: [ mel_spec_transform: [
!ref <mel_spec_transform1> !ref <mel_spec_transform1>
] ]

View File

@@ -147,7 +147,7 @@ hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
generator: !ref <hift> generator: !ref <hift>
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
mel_spec_transform: [ mel_spec_transform: [
!ref <mel_spec_transform1> !ref <mel_spec_transform1>
] ]

View File

@@ -0,0 +1,233 @@
# set random seed, so that you may reproduce your result.
__set_seed1: !apply:random.seed [1986]
__set_seed2: !apply:numpy.random.seed [1986]
__set_seed3: !apply:torch.manual_seed [1986]
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
# fixed params
sample_rate: 24000
llm_input_size: 896
llm_output_size: 896
spk_embed_dim: 192
qwen_pretrain_path: ''
token_frame_rate: 25
token_mel_ratio: 2
# stream related params
chunk_size: 25 # streaming inference chunk size, in token
num_decoding_left_chunks: 1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
# model params
# for every class/function included in this repo, we use !<name> or !<new> for initialization, so that users can find the corresponding class/function from this single yaml.
# for system/third_party class/function, we do not require this.
llm: !new:cosyvoice.llm.llm.Qwen2LM
llm_input_size: !ref <llm_input_size>
llm_output_size: !ref <llm_output_size>
speech_token_size: 6561
length_normalized_loss: True
lsm_weight: 0
mix_ratio: [5, 15]
llm: !new:cosyvoice.llm.llm.Qwen2Encoder
pretrain_path: !ref <qwen_pretrain_path>
sampling: !name:cosyvoice.utils.common.ras_sampling
top_p: 0.8
top_k: 25
win_size: 10
tau_r: 0.1
flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
input_size: 512
output_size: 80
spk_embed_dim: !ref <spk_embed_dim>
output_type: 'mel'
vocab_size: 6561
input_frame_rate: !ref <token_frame_rate>
only_mask_loss: True
token_mel_ratio: !ref <token_mel_ratio>
pre_lookahead_len: 3
encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
output_size: 512
attention_heads: 8
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
input_size: 512
use_cnn_module: False
macaron_style: False
static_chunk_size: !ref <chunk_size>
decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
in_channels: 240
n_spks: 1
spk_emb_dim: 80
cfm_params: !new:omegaconf.DictConfig
content:
sigma_min: 1e-06
solver: 'euler'
t_scheduler: 'cosine'
training_cfg_rate: 0.2
inference_cfg_rate: 0.7
reg_loss_type: 'l1'
estimator: !new:cosyvoice.flow.decoder.CausalConditionalDecoder
in_channels: 320
out_channels: 80
channels: [256]
dropout: 0.0
attention_head_dim: 64
n_blocks: 4
num_mid_blocks: 12
num_heads: 8
act_fn: 'gelu'
static_chunk_size: !ref <chunk_size> * <token_mel_ratio>
num_decoding_left_chunks: !ref <num_decoding_left_chunks>
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
in_channels: 80
base_channels: 512
nb_harmonics: 8
sampling_rate: !ref <sample_rate>
nsf_alpha: 0.1
nsf_sigma: 0.003
nsf_voiced_threshold: 10
upsample_rates: [8, 5, 3]
upsample_kernel_sizes: [16, 11, 7]
istft_params:
n_fft: 16
hop_len: 4
resblock_kernel_sizes: [3, 7, 11]
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
source_resblock_kernel_sizes: [7, 7, 11]
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
lrelu_slope: 0.1
audio_limit: 0.99
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
num_class: 1
in_channels: 80
cond_channels: 512
# gan related module
mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1920
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 480
win_size: 1920
fmin: 0
fmax: null
center: False
hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
generator: !ref <hift>
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
mel_spec_transform: [
!ref <mel_spec_transform1>
]
# processor functions
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
get_tokenizer: !name:cosyvoice.tokenizer.tokenizer.get_qwen_tokenizer
token_path: !ref <qwen_pretrain_path>
skip_special_tokens: True
allowed_special: 'all'
tokenize: !name:cosyvoice.dataset.processor.tokenize
get_tokenizer: !ref <get_tokenizer>
allowed_special: !ref <allowed_special>
filter: !name:cosyvoice.dataset.processor.filter
max_length: 40960
min_length: 100
token_max_length: 200
token_min_length: 1
resample: !name:cosyvoice.dataset.processor.resample
resample_rate: !ref <sample_rate>
truncate: !name:cosyvoice.dataset.processor.truncate
truncate_length: 24480 # must be a multiple of hop_size (24480 = 51 * 480)
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1920
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 480
win_size: 1920
fmin: 0
fmax: 8000
center: False
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
feat_extractor: !ref <feat_extractor>
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
sample_rate: !ref <sample_rate>
hop_size: 480
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
normalize: True
shuffle: !name:cosyvoice.dataset.processor.shuffle
shuffle_size: 1000
sort: !name:cosyvoice.dataset.processor.sort
sort_size: 500 # sort_size should be less than shuffle_size
batch: !name:cosyvoice.dataset.processor.batch
batch_type: 'dynamic'
max_frames_in_batch: 2000
padding: !name:cosyvoice.dataset.processor.padding
use_spk_embedding: False # change to True during sft
# dataset processor pipeline
data_pipeline: [
!ref <parquet_opener>,
!ref <tokenize>,
!ref <filter>,
!ref <resample>,
!ref <compute_fbank>,
!ref <parse_embedding>,
!ref <shuffle>,
!ref <sort>,
!ref <batch>,
!ref <padding>,
]
data_pipeline_gan: [
!ref <parquet_opener>,
!ref <tokenize>,
!ref <filter>,
!ref <resample>,
!ref <truncate>,
!ref <compute_fbank>,
!ref <compute_f0>,
!ref <parse_embedding>,
!ref <shuffle>,
!ref <sort>,
!ref <batch>,
!ref <padding>,
]
# llm flow train conf
train_conf:
optim: adam
optim_conf:
lr: 1e-5 # change to 1e-5 during sft
scheduler: constantlr # change to constantlr during sft
scheduler_conf:
warmup_steps: 2500
max_epoch: 200
grad_clip: 5
accum_grad: 2
log_interval: 100
save_per_step: -1
# gan train conf
train_conf_gan:
optim: adam
optim_conf:
lr: 0.0002 # use small lr for gan training
scheduler: constantlr
optim_d: adam
optim_conf_d:
lr: 0.0002 # use small lr for gan training
scheduler_d: constantlr
max_epoch: 200
grad_clip: 5
accum_grad: 1 # in gan training, accum_grad must be 1
log_interval: 100
save_per_step: -1
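The `!new:` / `!ref` / `!name:` tags above are hyperpyyaml syntax, so the whole stack (llm, flow, hift and the data-pipeline functions) can be instantiated straight from this file. A minimal loading sketch; the file location and the `qwen_pretrain_path` override are assumptions about your local layout:

```python
from hyperpyyaml import load_hyperpyyaml

# hypothetical paths; adjust to your checkout / pretrained_models layout
with open('conf/cosyvoice2.yaml', 'r') as f:
    configs = load_hyperpyyaml(
        f,
        overrides={'qwen_pretrain_path': '../../../pretrained_models/CosyVoice2-0.5B/CosyVoice-BlankEN'})

llm, flow, hift = configs['llm'], configs['flow'], configs['hift']
print(type(llm).__name__, type(flow).__name__, type(hift).__name__)
# Qwen2LM CausalMaskedDiffWithXvec HiFTGenerator
```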

View File

@@ -0,0 +1,42 @@
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"steps_per_print": 100,
"gradient_clipping": 5,
"fp16": {
"enabled": false,
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 256,
"hysteresis": 2,
"consecutive_hysteresis": false,
"min_loss_scale": 1
},
"bf16": {
"enabled": false
},
"zero_force_ds_cpu_optimizer": false,
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": false,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients" : true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": 0.001,
"weight_decay": 0.0001,
"torch_adam": true,
"adam_w_mode": true
}
}
}
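This ZeRO stage-2 config is consumed via `--deepspeed_config` in the recipe below; equivalently, a module can be wrapped directly, as in this minimal sketch (where `model` is assumed to be the llm/flow module under training):

```python
import deepspeed

# sketch only: `model` is assumed to be the torch.nn.Module being trained
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config='conf/ds_stage2.json')  # the JSON above; DeepSpeed builds its AdamW from it
```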

View File

@@ -0,0 +1 @@
../cosyvoice/local

View File

@@ -0,0 +1,3 @@
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=../../../:../../../third_party/Matcha-TTS:$PYTHONPATH

View File

@@ -0,0 +1,130 @@
#!/bin/bash
# Copyright 2024 Alibaba Inc. All Rights Reserved.
. ./path.sh || exit 1;
stage=-1
stop_stage=3
data_url=www.openslr.org/resources/60
data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
pretrained_model_dir=../../../pretrained_models/CosyVoice2-0.5B
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "Data Download"
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
local/download_and_untar.sh ${data_dir} ${data_url} ${part}
done
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
tools/extract_embedding.py --dir data/$x \
--onnx_path $pretrained_model_dir/campplus.onnx
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
tools/extract_speech_token.py --dir data/$x \
--onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
done
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x/parquet
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
--num_processes 10 \
--src_dir data/$x \
--des_dir data/$x/parquet
done
fi
# inference
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "Run inference. Please make sure utt in tts_text is in prompt_data"
# TODO consider removing bin/inference.py, or use a similar initialization method as in the readme
for mode in sft zero_shot; do
python cosyvoice/bin/inference.py --mode $mode \
--gpu 0 \
--config conf/cosyvoice2.yaml \
--prompt_data data/test-clean/parquet/data.list \
--prompt_utt2data data/test-clean/parquet/utt2data.list \
--tts_text `pwd`/tts_text.json \
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
--llm_model $pretrained_model_dir/llm.pt \
--flow_model $pretrained_model_dir/flow.pt \
--hifigan_model $pretrained_model_dir/hift.pt \
--result_dir `pwd`/exp/cosyvoice/test-clean/$mode
done
fi
# train llm
export CUDA_VISIBLE_DEVICES="0,1,2,3"
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
job_id=1986
dist_backend="nccl"
num_workers=2
prefetch=100
train_engine=torch_ddp
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
if [ $train_engine == 'deepspeed' ]; then
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
fi
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
# NOTE will update llm/hift training later
for model in llm flow; do
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
cosyvoice/bin/train.py \
--train_engine $train_engine \
--config conf/cosyvoice2.yaml \
--train_data data/train.data.list \
--cv_data data/dev.data.list \
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
--model $model \
--checkpoint $pretrained_model_dir/$model.pt \
--model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \
--tensorboard_dir `pwd`/tensorboard/cosyvoice2/$model/$train_engine \
--ddp.dist_backend $dist_backend \
--num_workers ${num_workers} \
--prefetch ${prefetch} \
--pin_memory \
--use_amp \
--deepspeed_config ./conf/ds_stage2.json \
--deepspeed.save_states model+optimizer
done
fi
# average model
average_num=5
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
for model in llm flow hifigan; do
decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
echo "do model average and final checkpoint is $decode_checkpoint"
python cosyvoice/bin/average_model.py \
--dst_model $decode_checkpoint \
--src_path `pwd`/exp/cosyvoice/$model/$train_engine \
--num ${average_num} \
--val_best
done
fi
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
fi

View File

@@ -0,0 +1,5 @@
{
"1089_134686_000002_000000": [
"hello, my name is Jack. What is your name?"
]
}

View File

@@ -13,7 +13,7 @@ inflect==7.3.1
librosa==0.10.2 librosa==0.10.2
lightning==2.2.4 lightning==2.2.4
matplotlib==3.7.5 matplotlib==3.7.5
modelscope==1.15.0 modelscope==1.20.0
networkx==3.1 networkx==3.1
omegaconf==2.3.0 omegaconf==2.3.0
onnx==1.16.0 onnx==1.16.0
@@ -21,6 +21,7 @@ onnxruntime-gpu==1.18.0; sys_platform == 'linux'
onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'win32' onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'win32'
openai-whisper==20231117 openai-whisper==20231117
protobuf==4.25 protobuf==4.25
pyarrow==18.1.0
pydantic==2.7.0 pydantic==2.7.0
pyworld==0.3.4 pyworld==0.3.4
rich==13.7.1 rich==13.7.1