diff --git a/README.md b/README.md index 62bdf1d..bd18016 100644 --- a/README.md +++ b/README.md @@ -128,7 +128,7 @@ import torchaudio **CosyVoice2 Usage** ```python -cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False) +cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False) # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference # zero_shot usage diff --git a/cosyvoice/bin/average_model.py b/cosyvoice/bin/average_model.py index d095dcd..b7140c1 100644 --- a/cosyvoice/bin/average_model.py +++ b/cosyvoice/bin/average_model.py @@ -75,10 +75,11 @@ def main(): print('Processing {}'.format(path)) states = torch.load(path, map_location=torch.device('cpu')) for k in states.keys(): - if k not in avg.keys(): - avg[k] = states[k].clone() - else: - avg[k] += states[k] + if k not in ['step', 'epoch']: + if k not in avg.keys(): + avg[k] = states[k].clone() + else: + avg[k] += states[k] # average for k in avg.keys(): if avg[k] is not None: diff --git a/cosyvoice/bin/export_jit.py b/cosyvoice/bin/export_jit.py index ddd486e..1e89005 100644 --- a/cosyvoice/bin/export_jit.py +++ b/cosyvoice/bin/export_jit.py @@ -24,6 +24,7 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append('{}/../..'.format(ROOT_DIR)) sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 +from cosyvoice.utils.file_utils import logging def get_args(): @@ -60,7 +61,8 @@ def main(): model = CosyVoice(args.model_dir) except Exception: try: - model = CosyVoice2(args.model_dir) + # NOTE set use_flow_cache=True when export jit for cache inference + model = CosyVoice2(args.model_dir, use_flow_cache=True) except Exception: raise TypeError('no valid model_type!') @@ -71,6 +73,7 @@ def main(): script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir)) script = get_optimized_script(llm_text_encoder.half()) script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir)) + logging.info('successfully export llm_text_encoder') # 2. export llm llm llm_llm = model.model.llm.llm @@ -78,13 +81,23 @@ def main(): script.save('{}/llm.llm.fp32.zip'.format(args.model_dir)) script = get_optimized_script(llm_llm.half(), ['forward_chunk']) script.save('{}/llm.llm.fp16.zip'.format(args.model_dir)) + logging.info('successfully export llm_llm') - # 3. export flow encoder - flow_encoder = model.model.flow.encoder - script = get_optimized_script(flow_encoder) - script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) - script = get_optimized_script(flow_encoder.half()) - script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir)) + # 3. export flow encoder + flow_encoder = model.model.flow.encoder + script = get_optimized_script(flow_encoder) + script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) + script = get_optimized_script(flow_encoder.half()) + script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir)) + logging.info('successfully export flow_encoder') + else: + # 3. export flow encoder + flow_encoder = model.model.flow.encoder + script = get_optimized_script(flow_encoder, ['forward_chunk']) + script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) + script = get_optimized_script(flow_encoder.half(), ['forward_chunk']) + script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir)) + logging.info('successfully export flow_encoder') if __name__ == '__main__': diff --git a/cosyvoice/bin/export_onnx.py b/cosyvoice/bin/export_onnx.py index 9ddd358..310f2c1 100644 --- a/cosyvoice/bin/export_onnx.py +++ b/cosyvoice/bin/export_onnx.py @@ -28,6 +28,7 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append('{}/../..'.format(ROOT_DIR)) sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 +from cosyvoice.utils.file_utils import logging def get_dummy_input(batch_size, seq_len, out_channels, device): @@ -51,6 +52,7 @@ def get_args(): return args +@torch.no_grad() def main(): args = get_args() logging.basicConfig(level=logging.DEBUG, @@ -60,56 +62,132 @@ def main(): model = CosyVoice(args.model_dir) except Exception: try: - model = CosyVoice2(args.model_dir) + # NOTE set use_flow_cache=True when export jit for cache inference + model = CosyVoice2(args.model_dir, use_flow_cache=True) except Exception: raise TypeError('no valid model_type!') - # 1. export flow decoder estimator - estimator = model.model.flow.decoder.estimator + if not isinstance(model, CosyVoice2): + # 1. export flow decoder estimator + estimator = model.model.flow.decoder.estimator + estimator.eval() - device = model.model.device - batch_size, seq_len = 2, 256 - out_channels = model.model.flow.decoder.estimator.out_channels - x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) - torch.onnx.export( - estimator, - (x, mask, mu, t, spks, cond), - '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), - export_params=True, - opset_version=18, - do_constant_folding=True, - input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], - output_names=['estimator_out'], - dynamic_axes={ - 'x': {2: 'seq_len'}, - 'mask': {2: 'seq_len'}, - 'mu': {2: 'seq_len'}, - 'cond': {2: 'seq_len'}, - 'estimator_out': {2: 'seq_len'}, - } - ) + device = model.model.device + batch_size, seq_len = 2, 256 + out_channels = model.model.flow.decoder.estimator.out_channels + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) + torch.onnx.export( + estimator, + (x, mask, mu, t, spks, cond), + '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], + output_names=['estimator_out'], + dynamic_axes={ + 'x': {2: 'seq_len'}, + 'mask': {2: 'seq_len'}, + 'mu': {2: 'seq_len'}, + 'cond': {2: 'seq_len'}, + 'estimator_out': {2: 'seq_len'}, + } + ) - # 2. test computation consistency - option = onnxruntime.SessionOptions() - option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - option.intra_op_num_threads = 1 - providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] - estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), - sess_options=option, providers=providers) + # 2. test computation consistency + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + sess_options=option, providers=providers) - for _ in tqdm(range(10)): - x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) - output_pytorch = estimator(x, mask, mu, t, spks, cond) - ort_inputs = { - 'x': x.cpu().numpy(), - 'mask': mask.cpu().numpy(), - 'mu': mu.cpu().numpy(), - 't': t.cpu().numpy(), - 'spks': spks.cpu().numpy(), - 'cond': cond.cpu().numpy() - } - output_onnx = estimator_onnx.run(None, ort_inputs)[0] - torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) + for _ in tqdm(range(10)): + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) + output_pytorch = estimator(x, mask, mu, t, spks, cond) + ort_inputs = { + 'x': x.cpu().numpy(), + 'mask': mask.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy() + } + output_onnx = estimator_onnx.run(None, ort_inputs)[0] + torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) + logging.info('successfully export estimator') + else: + # 1. export flow decoder estimator + estimator = model.model.flow.decoder.estimator + estimator.forward = estimator.forward_chunk + estimator.eval() + + device = model.model.device + batch_size, seq_len = 2, 256 + out_channels = model.model.flow.decoder.estimator.out_channels + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) + cache = model.model.init_flow_cache()['decoder_cache'] + cache.pop('offset') + cache = {k: v[0] for k, v in cache.items()} + torch.onnx.export( + estimator, + (x, mask, mu, t, spks, cond, + cache['down_blocks_conv_cache'], + cache['down_blocks_kv_cache'], + cache['mid_blocks_conv_cache'], + cache['mid_blocks_kv_cache'], + cache['up_blocks_conv_cache'], + cache['up_blocks_kv_cache'], + cache['final_blocks_conv_cache']), + '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache', + 'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'], + output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out', + 'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'], + dynamic_axes={ + 'x': {2: 'seq_len'}, + 'mask': {2: 'seq_len'}, + 'mu': {2: 'seq_len'}, + 'cond': {2: 'seq_len'}, + 'down_blocks_kv_cache': {3: 'cache_in_len'}, + 'mid_blocks_kv_cache': {3: 'cache_in_len'}, + 'up_blocks_kv_cache': {3: 'cache_in_len'}, + 'estimator_out': {2: 'seq_len'}, + 'down_blocks_kv_cache_out': {3: 'cache_out_len'}, + 'mid_blocks_kv_cache_out': {3: 'cache_out_len'}, + 'up_blocks_kv_cache_out': {3: 'cache_out_len'}, + } + ) + + # 2. test computation consistency + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + sess_options=option, providers=providers) + + for _ in tqdm(range(10)): + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) + cache = model.model.init_flow_cache()['decoder_cache'] + cache.pop('offset') + cache = {k: v[0] for k, v in cache.items()} + output_pytorch = estimator(x, mask, mu, t, spks, cond, **{k: v.clone() for k, v in cache.items()}) + ort_inputs = { + 'x': x.cpu().numpy(), + 'mask': mask.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy(), + } + output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}}) + for i, j in zip(output_pytorch, output_onnx): + torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4) + logging.info('successfully export estimator') if __name__ == "__main__": diff --git a/cosyvoice/bin/export_trt.sh b/cosyvoice/bin/export_trt.sh deleted file mode 100644 index 808d02a..0000000 --- a/cosyvoice/bin/export_trt.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash -# Copyright 2024 Alibaba Inc. All Rights Reserved. -# download tensorrt from https://developer.nvidia.com/tensorrt/download/10x, check your system and cuda for compatibability -# for example for linux + cuda12.4, you can download https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz -TRT_DIR= -MODEL_DIR= - -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64 -$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp32.mygpu.plan --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw --outputIOFormats=fp32:chw -$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw diff --git a/cosyvoice/bin/inference.py b/cosyvoice/bin/inference.py index 2cb831a..f6ec39f 100644 --- a/cosyvoice/bin/inference.py +++ b/cosyvoice/bin/inference.py @@ -23,7 +23,7 @@ from torch.utils.data import DataLoader import torchaudio from hyperpyyaml import load_hyperpyyaml from tqdm import tqdm -from cosyvoice.cli.model import CosyVoiceModel +from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model from cosyvoice.dataset.dataset import Dataset @@ -33,6 +33,7 @@ def get_args(): parser.add_argument('--prompt_data', required=True, help='prompt data file') parser.add_argument('--prompt_utt2data', required=True, help='prompt data file') parser.add_argument('--tts_text', required=True, help='tts input file') + parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path') parser.add_argument('--llm_model', required=True, help='llm model file') parser.add_argument('--flow_model', required=True, help='flow model file') parser.add_argument('--hifigan_model', required=True, help='hifigan model file') @@ -59,16 +60,25 @@ def main(): # Init cosyvoice models from configs use_cuda = args.gpu >= 0 and torch.cuda.is_available() device = torch.device('cuda' if use_cuda else 'cpu') - with open(args.config, 'r') as f: - configs = load_hyperpyyaml(f) + try: + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path}) + model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16=False) + except Exception: + try: + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f) + model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16=False) + except Exception: + raise TypeError('no valid model_type!') - model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift']) model.load(args.llm_model, args.flow_model, args.hifigan_model) test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data) test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) + sample_rate = configs['sample_rate'] del configs os.makedirs(args.result_dir, exist_ok=True) fn = os.path.join(args.result_dir, 'wav.scp') @@ -104,7 +114,7 @@ def main(): tts_speeches = torch.concat(tts_speeches, dim=1) tts_key = '{}_{}'.format(utts[0], tts_index[0]) tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key)) - torchaudio.save(tts_fn, tts_speeches, sample_rate=22050) + torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile') f.write('{} {}\n'.format(tts_key, tts_fn)) f.flush() f.close() diff --git a/cosyvoice/bin/train.py b/cosyvoice/bin/train.py index 3b4710e..b214e6a 100644 --- a/cosyvoice/bin/train.py +++ b/cosyvoice/bin/train.py @@ -46,6 +46,7 @@ def get_args(): parser.add_argument('--config', required=True, help='config file') parser.add_argument('--train_data', required=True, help='train data file') parser.add_argument('--cv_data', required=True, help='cv data file') + parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path') parser.add_argument('--checkpoint', help='checkpoint model') parser.add_argument('--model_dir', required=True, help='save model dir') parser.add_argument('--tensorboard_dir', @@ -97,8 +98,12 @@ def main(): override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model} if gan is True: override_dict.pop('hift') - with open(args.config, 'r') as f: - configs = load_hyperpyyaml(f, overrides=override_dict) + try: + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path}) + except Exception: + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f, overrides=override_dict) if gan is True: configs['train_conf'] = configs['train_conf_gan'] configs['train_conf'].update(vars(args)) diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index e2d62e2..bcff6ab 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -32,7 +32,10 @@ class CosyVoice: self.fp16 = fp16 if not os.path.exists(model_dir): model_dir = snapshot_download(model_dir) - with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f: + hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir) + if not os.path.exists(hyper_yaml_path): + raise ValueError('{} not found!'.format(hyper_yaml_path)) + with open(hyper_yaml_path, 'r') as f: configs = load_hyperpyyaml(f) assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir) self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'], @@ -126,13 +129,16 @@ class CosyVoice: class CosyVoice2(CosyVoice): - def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): + def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False): self.instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir self.fp16 = fp16 if not os.path.exists(model_dir): model_dir = snapshot_download(model_dir) - with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f: + hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir) + if not os.path.exists(hyper_yaml_path): + raise ValueError('{} not found!'.format(hyper_yaml_path)) + with open(hyper_yaml_path, 'r') as f: configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')}) assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir) self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'], @@ -145,9 +151,9 @@ class CosyVoice2(CosyVoice): if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True): load_jit, load_trt, fp16 = False, False, False logging.warning('no cuda device, set load_jit/load_trt/fp16 to False') - self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16) + self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache) self.model.load('{}/llm.pt'.format(model_dir), - '{}/flow.pt'.format(model_dir), + '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) if load_jit: self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 9ebf8cb..9a50991 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -36,16 +36,12 @@ class CosyVoiceModel: self.flow = flow self.hift = hift self.fp16 = fp16 - self.llm.fp16 = fp16 - self.flow.fp16 = fp16 if self.fp16 is True: self.llm.half() self.flow.half() self.token_min_hop_len = 2 * self.flow.input_frame_rate self.token_max_hop_len = 4 * self.flow.input_frame_rate self.token_overlap_len = 20 - # here we fix set flow.decoder.estimator.static_chunk_size = 0 for compatibability - self.flow.decoder.estimator.static_chunk_size = 0 # mel fade in out self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256) self.mel_window = np.hamming(2 * self.mel_overlap_len) @@ -87,19 +83,25 @@ class CosyVoiceModel: def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16): assert torch.cuda.is_available(), 'tensorrt only supports gpu!' if not os.path.exists(flow_decoder_estimator_model): - convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16) + convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16) if os.path.getsize(flow_decoder_estimator_model) == 0: raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model)) del self.flow.decoder.estimator import tensorrt as trt with open(flow_decoder_estimator_model, 'rb') as f: self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) - if self.flow.decoder.estimator_engine is None: - raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model)) + assert self.flow.decoder.estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context() + def get_trt_kwargs(self): + min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)] + opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200)] + max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)] + input_names = ["x", "mask", "mu", "cond"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): - with self.llm_context: + with self.llm_context, torch.cuda.amp.autocast(self.fp16): if isinstance(text, Generator): assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!' for i in self.llm.inference_bistream(text=text, @@ -121,15 +123,15 @@ class CosyVoiceModel: self.llm_end_dict[uuid] = True def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0): - tts_mel, flow_cache = self.flow.inference(token=token.to(self.device), - token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), - prompt_token=prompt_token.to(self.device), - prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), - prompt_feat=prompt_feat.to(self.device), - prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), - embedding=embedding.to(self.device), - flow_cache=self.flow_cache_dict[uuid]) - self.flow_cache_dict[uuid] = flow_cache + with torch.cuda.amp.autocast(self.fp16): + tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device), + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=embedding.to(self.device), + flow_cache=self.flow_cache_dict[uuid]) # mel overlap fade in out if self.mel_overlap_dict[uuid].shape[2] != 0: @@ -276,6 +278,7 @@ class CosyVoiceModel: self.llm_end_dict.pop(this_uuid) self.mel_overlap_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) + self.flow_cache_dict.pop(this_uuid) torch.cuda.empty_cache() @@ -285,49 +288,88 @@ class CosyVoice2Model(CosyVoiceModel): llm: torch.nn.Module, flow: torch.nn.Module, hift: torch.nn.Module, - fp16: bool): + fp16: bool, + use_flow_cache: bool): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.llm = llm self.flow = flow self.hift = hift self.fp16 = fp16 - self.llm.fp16 = fp16 - self.flow.fp16 = fp16 + self.use_flow_cache = use_flow_cache if self.fp16 is True: self.llm.half() self.flow.half() - self.token_hop_len = 2 * self.flow.input_frame_rate - # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache - self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate - self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio + # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml + self.token_hop_len = 25 + self.flow_decoder_required_cache_size = -1 if use_flow_cache is False else 1 * self.token_hop_len # hift cache self.mel_cache_len = 8 self.source_cache_len = int(self.mel_cache_len * 480) # speech fade in out self.speech_window = np.hamming(2 * self.source_cache_len) # rtf and decoding related - self.stream_scale_factor = 1 self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() self.lock = threading.Lock() # dict used to store session related variable self.tts_speech_token_dict = {} self.llm_end_dict = {} + self.flow_cache_dict = {} self.hift_cache_dict = {} + def init_flow_cache(self): + encoder_cache = {'offset': 0, + 'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device), + 'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device), + 'upsample_offset': 0, + 'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device), + 'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)} + decoder_cache = {'offset': 0, + 'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device), + 'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device), + 'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device), + 'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, 0, 512, 2).to(self.device), + 'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device), + 'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device), + 'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)} + if self.fp16 is True: + for cache in [encoder_cache, decoder_cache]: + for k, v in cache.items(): + if isinstance(v, torch.Tensor): + cache[k] = v.half() + cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache} + return cache + + def trim_flow_cache(self, cache): + if self.flow_decoder_required_cache_size > 0: + cache['decoder_cache']['down_blocks_kv_cache'] = cache['decoder_cache']['down_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:] + cache['decoder_cache']['mid_blocks_kv_cache'] = cache['decoder_cache']['mid_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:] + cache['decoder_cache']['up_blocks_kv_cache'] = cache['decoder_cache']['up_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:] + return cache + def load_jit(self, flow_encoder_model): flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder - def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0): - tts_mel, _ = self.flow.inference(token=token.to(self.device), - token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), - prompt_token=prompt_token.to(self.device), - prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), - prompt_feat=prompt_feat.to(self.device), - prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), - embedding=embedding.to(self.device), - finalize=finalize) - tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:] + def get_trt_kwargs(self): + min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (1, 4, 2, 0, 512, 2), (12, 4, 2, 0, 512, 2), (1, 4, 2, 0, 512, 2)] + opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200), (1, 4, 2, 100, 512, 2), (12, 4, 2, 100, 512, 2), (1, 4, 2, 100, 512, 2)] + max_shape = [(2, 80, 1500), (2, 1, 1500), (2, 80, 1500), (2, 80, 1500), (1, 4, 2, 200, 512, 2), (12, 4, 2, 200, 512, 2), (1, 4, 2, 200, 512, 2)] + input_names = ["x", "mask", "mu", "cond", 'down_blocks_kv_cache', 'mid_blocks_kv_cache', 'up_blocks_kv_cache'] + assert self.use_flow_cache is True, "get_trt_kwargs is set for flow cache mode. If you want to use trt with use_flow_cache=False, please set higher max_shape" + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0): + with torch.cuda.amp.autocast(self.fp16): + tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device), + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=embedding.to(self.device), + cache=self.flow_cache_dict[uuid], + finalize=finalize) + self.flow_cache_dict[uuid] = self.trim_flow_cache(self.flow_cache_dict[uuid]) # append hift cache if self.hift_cache_dict[uuid] is not None: hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source'] @@ -359,27 +401,34 @@ class CosyVoice2Model(CosyVoiceModel): prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread this_uuid = str(uuid.uuid1()) + # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size + if self.use_flow_cache is True: + flow_prompt_speech_token = flow_prompt_speech_token[:, -self.flow_decoder_required_cache_size:] + prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size * 2:] with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None + self.flow_cache_dict[this_uuid] = self.init_flow_cache() p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) p.start() if stream is True: - token_offset = 0 while True: time.sleep(0.1) - if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len: - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) + if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len: + this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, embedding=flow_embedding, uuid=this_uuid, - token_offset=token_offset, finalize=False) - token_offset += self.token_hop_len + # NOTE in cache inference mode, we only use flow_prompt_speech_token/prompt_speech_feat in first chunk + flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device) + prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device) yield {'tts_speech': this_tts_speech.cpu()} - if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len: + with self.lock: + self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:] + if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len: break p.join() # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None @@ -389,7 +438,6 @@ class CosyVoice2Model(CosyVoiceModel): prompt_feat=prompt_speech_feat, embedding=flow_embedding, uuid=this_uuid, - token_offset=token_offset, finalize=True) yield {'tts_speech': this_tts_speech.cpu()} else: @@ -401,11 +449,12 @@ class CosyVoice2Model(CosyVoiceModel): prompt_feat=prompt_speech_feat, embedding=flow_embedding, uuid=this_uuid, - token_offset=0, finalize=True, speed=speed) yield {'tts_speech': this_tts_speech.cpu()} with self.lock: self.tts_speech_token_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid) + self.hift_cache_dict.pop(this_uuid) + self.flow_cache_dict.pop(this_uuid) torch.cuda.empty_cache() diff --git a/cosyvoice/dataset/processor.py b/cosyvoice/dataset/processor.py index 67434c7..8424ada 100644 --- a/cosyvoice/dataset/processor.py +++ b/cosyvoice/dataset/processor.py @@ -196,8 +196,8 @@ def compute_f0(data, sample_rate, hop_size, mode='train'): assert 'text_token' in sample waveform = sample['speech'] _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) - if sum(_f0 != 0) < 5: # this happens when the algorithm fails - _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio + if sum(_f0 != 0) < 5: # this happens when the algorithm fails + _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate) f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1) sample['pitch_feat'] = f0 diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py index 420a1bf..261cf09 100644 --- a/cosyvoice/flow/decoder.py +++ b/cosyvoice/flow/decoder.py @@ -11,14 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple, Optional, Dict, Any import torch import torch.nn as nn import torch.nn.functional as F from einops import pack, rearrange, repeat +from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate from cosyvoice.utils.common import mask_to_bias from cosyvoice.utils.mask import add_optional_chunk_mask from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D -from matcha.models.components.transformer import BasicTransformerBlock +from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph class Transpose(torch.nn.Module): @@ -27,34 +29,11 @@ class Transpose(torch.nn.Module): self.dim0 = dim0 self.dim1 = dim1 - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]: x = torch.transpose(x, self.dim0, self.dim1) return x -class CausalBlock1D(Block1D): - def __init__(self, dim: int, dim_out: int): - super(CausalBlock1D, self).__init__(dim, dim_out) - self.block = torch.nn.Sequential( - CausalConv1d(dim, dim_out, 3), - Transpose(1, 2), - nn.LayerNorm(dim_out), - Transpose(1, 2), - nn.Mish(), - ) - - def forward(self, x: torch.Tensor, mask: torch.Tensor): - output = self.block(x * mask) - return output * mask - - -class CausalResnetBlock1D(ResnetBlock1D): - def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8): - super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups) - self.block1 = CausalBlock1D(dim, dim_out) - self.block2 = CausalBlock1D(dim_out, dim_out) - - class CausalConv1d(torch.nn.Conv1d): def __init__( self, @@ -76,12 +55,339 @@ class CausalConv1d(torch.nn.Conv1d): padding_mode=padding_mode, device=device, dtype=dtype) assert stride == 1 - self.causal_padding = (kernel_size - 1, 0) + self.causal_padding = kernel_size - 1 - def forward(self, x: torch.Tensor): - x = F.pad(x, self.causal_padding) + def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]: + if cache.size(2) == 0: + x = F.pad(x, (self.causal_padding, 0), value=0.0) + else: + assert cache.size(2) == self.causal_padding + x = torch.concat([cache, x], dim=2) + cache = x[:, :, -self.causal_padding:] x = super(CausalConv1d, self).forward(x) - return x + return x, cache + + +class CausalBlock1D(Block1D): + def __init__(self, dim: int, dim_out: int): + super(CausalBlock1D, self).__init__(dim, dim_out) + self.block = torch.nn.Sequential( + CausalConv1d(dim, dim_out, 3), + Transpose(1, 2), + nn.LayerNorm(dim_out), + Transpose(1, 2), + nn.Mish(), + ) + + def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]: + output, cache = self.block[0](x * mask, cache) + for i in range(1, len(self.block)): + output = self.block[i](output) + return output * mask, cache + + +class CausalResnetBlock1D(ResnetBlock1D): + def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8): + super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups) + self.block1 = CausalBlock1D(dim, dim_out) + self.block2 = CausalBlock1D(dim_out, dim_out) + + def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor, + block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0) + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + h, block1_cache = self.block1(x, mask, block1_cache) + h += self.mlp(time_emb).unsqueeze(-1) + h, block2_cache = self.block2(h, mask, block2_cache) + output = h + self.res_conv(x * mask) + return output, block1_cache, block2_cache + + +class CausalAttnProcessor2_0(AttnProcessor2_0): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + super(CausalAttnProcessor2_0, self).__init__() + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + *args, + **kwargs, + ) -> Tuple[torch.FloatTensor, torch.Tensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \ + `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + # NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key_cache = attn.to_k(encoder_hidden_states) + value_cache = attn.to_v(encoder_hidden_states) + # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2) + if cache.size(0) != 0: + key = torch.concat([cache[:, :, :, 0], key_cache], dim=1) + value = torch.concat([cache[:, :, :, 1], value_cache], dim=1) + else: + key, value = key_cache, value_cache + cache = torch.stack([key_cache, value_cache], dim=3) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states, cache + + +@maybe_allow_in_graph +class CausalAttention(Attention): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor2_0"] = None, + out_dim: int = None, + ): + super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, + cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups, + spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection, + _from_deprecated_attn_block, processor, out_dim) + processor = CausalAttnProcessor2_0() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + **cross_attention_kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cache=cache, + **cross_attention_kwargs, + ) + + +@maybe_allow_in_graph +class CausalBasicTransformerBlock(BasicTransformerBlock): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + ): + super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout, + cross_attention_dim, activation_fn, num_embeds_ada_norm, + attention_bias, only_cross_attention, double_self_attention, + upcast_attention, norm_elementwise_affine, norm_type, final_dropout) + self.attn1 = CausalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + attn_output, cache = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, + cache=cache, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \ + {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.") + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states, cache class ConditionalDecoder(nn.Module): @@ -89,7 +395,6 @@ class ConditionalDecoder(nn.Module): self, in_channels, out_channels, - causal=False, channels=(256, 256), dropout=0.05, attention_head_dim=64, @@ -106,7 +411,7 @@ class ConditionalDecoder(nn.Module): channels = tuple(channels) self.in_channels = in_channels self.out_channels = out_channels - self.causal = causal + self.time_embeddings = SinusoidalPosEmb(in_channels) time_embed_dim = channels[0] * 4 self.time_mlp = TimestepEmbedding( @@ -123,8 +428,7 @@ class ConditionalDecoder(nn.Module): input_channel = output_channel output_channel = channels[i] is_last = i == len(channels) - 1 - resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ - ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( @@ -138,16 +442,14 @@ class ConditionalDecoder(nn.Module): ] ) downsample = ( - Downsample1D(output_channel) if not is_last else - CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) + Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1) ) self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) for _ in range(num_mid_blocks): input_channel = channels[-1] out_channels = channels[-1] - resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ - ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) transformer_blocks = nn.ModuleList( [ @@ -169,11 +471,7 @@ class ConditionalDecoder(nn.Module): input_channel = channels[i] * 2 output_channel = channels[i + 1] is_last = i == len(channels) - 2 - resnet = CausalResnetBlock1D( - dim=input_channel, - dim_out=output_channel, - time_emb_dim=time_embed_dim, - ) if self.causal else ResnetBlock1D( + resnet = ResnetBlock1D( dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim, @@ -193,10 +491,10 @@ class ConditionalDecoder(nn.Module): upsample = ( Upsample1D(output_channel, use_conv_transpose=True) if not is_last - else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) + else nn.Conv1d(output_channel, output_channel, 3, padding=1) ) self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) - self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1]) + self.final_block = Block1D(channels[-1], channels[-1]) self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) self.initialize_weights() @@ -214,7 +512,7 @@ class ConditionalDecoder(nn.Module): if m.bias is not None: nn.init.constant_(m.bias, 0) - def forward(self, x, mask, mu, t, spks=None, cond=None): + def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False): """Forward pass of the UNet1DConditional model. Args: @@ -249,9 +547,8 @@ class ConditionalDecoder(nn.Module): mask_down = masks[-1] x = resnet(x, mask_down, t) x = rearrange(x, "b c t -> b t c").contiguous() - # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) - attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1) - attn_mask = mask_to_bias(attn_mask == 1, x.dtype) + attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) for transformer_block in transformer_blocks: x = transformer_block( hidden_states=x, @@ -268,9 +565,8 @@ class ConditionalDecoder(nn.Module): for resnet, transformer_blocks in self.mid_blocks: x = resnet(x, mask_mid, t) x = rearrange(x, "b c t -> b t c").contiguous() - # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) - attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1) - attn_mask = mask_to_bias(attn_mask == 1, x.dtype) + attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) for transformer_block in transformer_blocks: x = transformer_block( hidden_states=x, @@ -285,9 +581,8 @@ class ConditionalDecoder(nn.Module): x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] x = resnet(x, mask_up, t) x = rearrange(x, "b c t -> b t c").contiguous() - # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) - attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1) - attn_mask = mask_to_bias(attn_mask == 1, x.dtype) + attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) for transformer_block in transformer_blocks: x = transformer_block( hidden_states=x, @@ -299,3 +594,309 @@ class ConditionalDecoder(nn.Module): x = self.final_block(x, mask_up) output = self.final_proj(x * mask_up) return output * mask + + +class CausalConditionalDecoder(ConditionalDecoder): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + static_chunk_size=50, + num_decoding_left_chunks=2, + ): + """ + This decoder requires an input with the same shape of the target. So, if your text content + is shorter or longer than the outputs, please re-sampling it before feeding to the decoder. + """ + torch.nn.Module.__init__(self) + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + self.static_chunk_size = static_chunk_size + self.num_decoding_left_chunks = num_decoding_left_chunks + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.ModuleList( + [ + CausalBasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) + ) + self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) + + for _ in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.ModuleList( + [ + CausalBasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i in range(len(channels) - 1): + input_channel = channels[i] * 2 + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + resnet = CausalResnetBlock1D( + dim=input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + CausalBasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else CausalConv1d(output_channel, output_channel, 3) + ) + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + self.final_block = CausalBlock1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + self.initialize_weights() + + def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. + + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + t = self.time_embeddings(t).to(t.dtype) + t = self.time_mlp(t) + + x = pack([x, mu], "b * t")[0] + + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + if cond is not None: + x = pack([x, cond], "b * t")[0] + + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x, _, _ = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c").contiguous() + if streaming is True: + attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks) + else: + attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x, _ = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + hiddens.append(x) # Save hidden states for skip connections + x, _ = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x, _, _ = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c").contiguous() + if streaming is True: + attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks) + else: + attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x, _ = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + skip = hiddens.pop() + x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] + x, _, _ = resnet(x, mask_up, t) + x = rearrange(x, "b c t -> b t c").contiguous() + if streaming is True: + attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks) + else: + attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x, _ = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + x, _ = upsample(x * mask_up) + x, _ = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + return output * mask + + def forward_chunk(self, x, mask, mu, t, spks=None, cond=None, + down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0), + mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0), + up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0), + final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0) + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. + + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + t = self.time_embeddings(t).to(t.dtype) + t = self.time_mlp(t) + + x = pack([x, mu], "b * t")[0] + + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + if cond is not None: + x = pack([x, cond], "b * t")[0] + + hiddens = [] + masks = [mask] + + down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device) + mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device) + up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device) + for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks): + mask_down = masks[-1] + x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \ + resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576]) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool() + attn_mask = mask_to_bias(attn_mask, x.dtype) + for i, transformer_block in enumerate(transformer_blocks): + x, down_blocks_kv_cache_new[index, i] = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + cache=down_blocks_kv_cache[index, i], + ) + x = rearrange(x, "b t c -> b c t").contiguous() + hiddens.append(x) # Save hidden states for skip connections + x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:]) + masks.append(mask_down[:, :, ::2]) + masks = masks[:-1] + mask_mid = masks[-1] + + for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks): + x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \ + resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:]) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool() + attn_mask = mask_to_bias(attn_mask, x.dtype) + for i, transformer_block in enumerate(transformer_blocks): + x, mid_blocks_kv_cache_new[index, i] = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + cache=mid_blocks_kv_cache[index, i] + ) + x = rearrange(x, "b t c -> b c t").contiguous() + + for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks): + mask_up = masks.pop() + skip = hiddens.pop() + x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] + x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \ + resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768]) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool() + attn_mask = mask_to_bias(attn_mask, x.dtype) + for i, transformer_block in enumerate(transformer_blocks): + x, up_blocks_kv_cache_new[index, i] = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + cache=up_blocks_kv_cache[index, i] + ) + x = rearrange(x, "b t c -> b c t").contiguous() + x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:]) + x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache) + output = self.final_proj(x * mask_up) + return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \ + up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index 72bb34c..9c642ee 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -91,6 +91,7 @@ class MaskedDiffWithXvec(torch.nn.Module): conds = conds.transpose(1, 2) mask = (~make_pad_mask(feat_len)).to(h) + # NOTE this is unnecessary, feat/h already same shape feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1) loss, _ = self.decoder.compute_loss( feat.transpose(1, 2).contiguous(), @@ -111,16 +112,12 @@ class MaskedDiffWithXvec(torch.nn.Module): prompt_feat_len, embedding, flow_cache): - if self.fp16 is True: - prompt_feat = prompt_feat.half() - embedding = embedding.half() - assert token.shape[0] == 1 # xvec projection embedding = F.normalize(embedding, dim=1) embedding = self.spk_embed_affine_layer(embedding) - # concat text and prompt_text + # concat speech token and prompt speech token token_len1, token_len2 = prompt_token.shape[1], token.shape[1] token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) @@ -145,7 +142,7 @@ class MaskedDiffWithXvec(torch.nn.Module): cond=conds, n_timesteps=10, prompt_len=mel_len1, - flow_cache=flow_cache + cache=flow_cache ) feat = feat[:, :, mel_len1:] assert feat.shape[2] == mel_len2 @@ -190,6 +187,53 @@ class CausalMaskedDiffWithXvec(torch.nn.Module): self.token_mel_ratio = token_mel_ratio self.pre_lookahead_len = pre_lookahead_len + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + token = batch['speech_token'].to(device) + token_len = batch['speech_token_len'].to(device) + feat = batch['speech_feat'].to(device) + feat_len = batch['speech_feat_len'].to(device) + embedding = batch['embedding'].to(device) + + # NOTE unified training, static_chunk_size > 0 or = 0 + streaming = True if random.random() < 0.5 else False + + # xvec projection + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + + # concat text and prompt_text + mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) + token = self.input_embedding(torch.clamp(token, min=0)) * mask + + # text encode + h, h_lengths = self.encoder(token, token_len, streaming=streaming) + h = self.encoder_proj(h) + + # get conditions + feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1) + conds = torch.zeros(feat.shape, device=token.device) + for i, j in enumerate(feat_len): + if random.random() < 0.5: + continue + index = random.randint(0, int(0.3 * j)) + conds[i, :index] = feat[i, :index] + conds = conds.transpose(1, 2) + + mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h) + loss, _ = self.decoder.compute_loss( + feat.transpose(1, 2).contiguous(), + mask.unsqueeze(1), + h.transpose(1, 2).contiguous(), + embedding, + cond=conds, + streaming=streaming, + ) + return {'loss': loss} + @torch.inference_mode() def inference(self, token, @@ -199,11 +243,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module): prompt_feat, prompt_feat_len, embedding, + cache, finalize): - if self.fp16 is True: - prompt_feat = prompt_feat.half() - embedding = embedding.half() - assert token.shape[0] == 1 # xvec projection embedding = F.normalize(embedding, dim=1) @@ -215,9 +256,17 @@ class CausalMaskedDiffWithXvec(torch.nn.Module): token = self.input_embedding(torch.clamp(token, min=0)) * mask # text encode - h, h_lengths = self.encoder(token, token_len) - if finalize is False: - h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio] + if finalize is True: + h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache']) + else: + token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:] + h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache']) + cache['encoder_cache']['offset'] = encoder_cache[0] + cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1] + cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2] + cache['encoder_cache']['upsample_offset'] = encoder_cache[3] + cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4] + cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5] mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1] h = self.encoder_proj(h) @@ -227,13 +276,14 @@ class CausalMaskedDiffWithXvec(torch.nn.Module): conds = conds.transpose(1, 2) mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) - feat, _ = self.decoder( + feat, cache['decoder_cache'] = self.decoder( mu=h.transpose(1, 2).contiguous(), mask=mask.unsqueeze(1), spks=embedding, cond=conds, - n_timesteps=10 + n_timesteps=10, + cache=cache['decoder_cache'] ) feat = feat[:, :, mel_len1:] assert feat.shape[2] == mel_len2 - return feat.float(), None + return feat.float(), cache diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 6a60f6d..4039896 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -34,7 +34,7 @@ class ConditionalCFM(BASECFM): self.lock = threading.Lock() @torch.inference_mode() - def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)): + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)): """Forward diffusion Args: @@ -54,19 +54,19 @@ class ConditionalCFM(BASECFM): """ z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature - cache_size = flow_cache.shape[2] + cache_size = cache.shape[2] # fix prompt and overlap part mu and z if cache_size != 0: - z[:, :, :cache_size] = flow_cache[:, :, :, 0] - mu[:, :, :cache_size] = flow_cache[:, :, :, 1] + z[:, :, :cache_size] = cache[:, :, :, 0] + mu[:, :, :cache_size] = cache[:, :, :, 1] z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2) mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2) - flow_cache = torch.stack([z_cache, mu_cache], dim=-1) + cache = torch.stack([z_cache, mu_cache], dim=-1) t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) if self.t_scheduler == 'cosine': t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) - return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache def solve_euler(self, x, t_span, mu, mask, spks, cond): """ @@ -123,7 +123,7 @@ class ConditionalCFM(BASECFM): def forward_estimator(self, x, mask, mu, t, spks, cond): if isinstance(self.estimator, torch.nn.Module): - return self.estimator.forward(x, mask, mu, t, spks, cond) + return self.estimator(x, mask, mu, t, spks, cond) else: with self.lock: self.estimator.set_input_shape('x', (2, 80, x.size(2))) @@ -133,16 +133,16 @@ class ConditionalCFM(BASECFM): self.estimator.set_input_shape('spks', (2, 80)) self.estimator.set_input_shape('cond', (2, 80, x.size(2))) # run trt engine - self.estimator.execute_v2([x.contiguous().data_ptr(), - mask.contiguous().data_ptr(), - mu.contiguous().data_ptr(), - t.contiguous().data_ptr(), - spks.contiguous().data_ptr(), - cond.contiguous().data_ptr(), - x.data_ptr()]) + assert self.estimator.execute_v2([x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()]) is True return x - def compute_loss(self, x1, mask, mu, spks=None, cond=None): + def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False): """Computes diffusion loss Args: @@ -179,7 +179,7 @@ class ConditionalCFM(BASECFM): spks = spks * cfg_mask.view(-1, 1) cond = cond * cfg_mask.view(-1, 1, 1) - pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond) + pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming) loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) return loss, y @@ -190,7 +190,7 @@ class CausalConditionalCFM(ConditionalCFM): self.rand_noise = torch.randn([1, 80, 50 * 300]) @torch.inference_mode() - def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}): """Forward diffusion Args: @@ -209,9 +209,136 @@ class CausalConditionalCFM(ConditionalCFM): shape: (batch_size, n_feats, mel_timesteps) """ - z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature + offset = cache.pop('offset') + z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature + z = z[:, :, offset:] + offset += mu.size(2) # fix prompt and overlap part mu and z t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) if self.t_scheduler == 'cosine': t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) - return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None + mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache) + cache['offset'] = offset + return mel, cache + + def solve_euler(self, x, t_span, mu, mask, spks, cond, cache): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + t = t.unsqueeze(dim=0) + + # I am storing this because I can later plot it by putting a debugger here and saving it to a file + # Or in future might add like a return_all_steps flag + sol = [] + + # estimator cache for each step + down_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x) + mid_blocks_kv_cache_new = torch.zeros(10, 12, 4, 2, x.size(2), 512, 2).to(x) + up_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x) + + # Do not use concat, it may cause memory format changed and trt infer with wrong results! + x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) + mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype) + mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) + t_in = torch.zeros([2], device=x.device, dtype=x.dtype) + spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype) + cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) + for step in range(1, len(t_span)): + # Classifier-Free Guidance inference introduced in VoiceBox + x_in[:] = x + mask_in[:] = mask + mu_in[0] = mu + t_in[:] = t.unsqueeze(0) + spks_in[0] = spks + cond_in[0] = cond + cache_step = {k: v[step - 1] for k, v in cache.items()} + dphi_dt, cache_step = self.forward_estimator( + x_in, mask_in, + mu_in, t_in, + spks_in, + cond_in, + cache_step + ) + cache['down_blocks_conv_cache'][step - 1] = cache_step[0] + down_blocks_kv_cache_new[step - 1] = cache_step[1] + cache['mid_blocks_conv_cache'][step - 1] = cache_step[2] + mid_blocks_kv_cache_new[step - 1] = cache_step[3] + cache['up_blocks_conv_cache'][step - 1] = cache_step[4] + up_blocks_kv_cache_new[step - 1] = cache_step[5] + cache['final_blocks_conv_cache'][step - 1] = cache_step[6] + dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) + dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + cache['down_blocks_kv_cache'] = torch.concat([cache['down_blocks_kv_cache'], down_blocks_kv_cache_new], dim=4) + cache['mid_blocks_kv_cache'] = torch.concat([cache['mid_blocks_kv_cache'], mid_blocks_kv_cache_new], dim=4) + cache['up_blocks_kv_cache'] = torch.concat([cache['up_blocks_kv_cache'], up_blocks_kv_cache_new], dim=4) + return sol[-1].float(), cache + + def forward_estimator(self, x, mask, mu, t, spks, cond, cache): + if isinstance(self.estimator, torch.nn.Module): + x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache) + cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7) + else: + with self.lock: + self.estimator.set_input_shape('x', (2, 80, x.size(2))) + self.estimator.set_input_shape('mask', (2, 1, x.size(2))) + self.estimator.set_input_shape('mu', (2, 80, x.size(2))) + self.estimator.set_input_shape('t', (2,)) + self.estimator.set_input_shape('spks', (2, 80)) + self.estimator.set_input_shape('cond', (2, 80, x.size(2))) + self.estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape) + self.estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape) + self.estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape) + self.estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape) + self.estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape) + self.estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape) + self.estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape) + # run trt engine + down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x) + mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x) + up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x) + assert self.estimator.execute_v2([x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + cache['down_blocks_conv_cache'].contiguous().data_ptr(), + cache['down_blocks_kv_cache'].contiguous().data_ptr(), + cache['mid_blocks_conv_cache'].contiguous().data_ptr(), + cache['mid_blocks_kv_cache'].contiguous().data_ptr(), + cache['up_blocks_conv_cache'].contiguous().data_ptr(), + cache['up_blocks_kv_cache'].contiguous().data_ptr(), + cache['final_blocks_conv_cache'].contiguous().data_ptr(), + x.data_ptr(), + cache['down_blocks_conv_cache'].data_ptr(), + down_blocks_kv_cache_out.data_ptr(), + cache['mid_blocks_conv_cache'].data_ptr(), + mid_blocks_kv_cache_out.data_ptr(), + cache['up_blocks_conv_cache'].data_ptr(), + up_blocks_kv_cache_out.data_ptr(), + cache['final_blocks_conv_cache'].data_ptr()]) is True + cache = (cache['down_blocks_conv_cache'], + down_blocks_kv_cache_out, + cache['mid_blocks_conv_cache'], + mid_blocks_kv_cache_out, + cache['up_blocks_conv_cache'], + up_blocks_kv_cache_out, + cache['final_blocks_conv_cache']) + return x, cache diff --git a/cosyvoice/flow/length_regulator.py b/cosyvoice/flow/length_regulator.py index 2cae42f..e1b6c1b 100644 --- a/cosyvoice/flow/length_regulator.py +++ b/cosyvoice/flow/length_regulator.py @@ -51,6 +51,7 @@ class InterpolateRegulator(nn.Module): def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50): # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel + # NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py # x in (B, T, D) if x2.shape[1] > 40: x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') diff --git a/cosyvoice/hifigan/discriminator.py b/cosyvoice/hifigan/discriminator.py index 1a4dcc8..bb8e85f 100644 --- a/cosyvoice/hifigan/discriminator.py +++ b/cosyvoice/hifigan/discriminator.py @@ -1,10 +1,16 @@ import torch import torch.nn as nn -from torch.nn.utils.parametrizations import weight_norm +import torch.nn.functional as F +try: + from torch.nn.utils.parametrizations import weight_norm, spectral_norm +except ImportError: + from torch.nn.utils import weight_norm, spectral_norm from typing import List, Optional, Tuple from einops import rearrange from torchaudio.transforms import Spectrogram +LRELU_SLOPE = 0.1 + class MultipleDiscriminator(nn.Module): def __init__( @@ -138,3 +144,87 @@ class DiscriminatorR(nn.Module): x += h return x, fmap + + +class MultiResSpecDiscriminator(torch.nn.Module): + + def __init__(self, + fft_sizes=[1024, 2048, 512], + hop_sizes=[120, 240, 50], + win_lengths=[600, 1200, 240], + window="hann_window"): + + super(MultiResSpecDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window), + SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window), + SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for _, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def stft(x, fft_size, hop_size, win_length, window): + """Perform STFT and convert to magnitude spectrogram. + Args: + x (Tensor): Input signal tensor (B, T). + fft_size (int): FFT size. + hop_size (int): Hop size. + win_length (int): Window length. + window (str): Window function type. + Returns: + Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). + """ + x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True) + + # NOTE(kan-bayashi): clamp is needed to avoid nan or inf + return torch.abs(x_stft).transpose(2, 1) + + +class SpecDiscriminator(nn.Module): + """docstring for Discriminator.""" + + def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False): + super(SpecDiscriminator, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.fft_size = fft_size + self.shift_size = shift_size + self.win_length = win_length + self.window = getattr(torch, window)(win_length) + self.discriminators = nn.ModuleList([ + norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), + ]) + + self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1)) + + def forward(self, y): + + fmap = [] + y = y.squeeze(1) + y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device)) + y = y.unsqueeze(1) + for _, d in enumerate(self.discriminators): + y = d(y) + y = F.leaky_relu(y, LRELU_SLOPE) + fmap.append(y) + + y = self.out(y) + fmap.append(y) + + return torch.flatten(y, 1, -1), fmap diff --git a/cosyvoice/hifigan/f0_predictor.py b/cosyvoice/hifigan/f0_predictor.py index 172c5f5..5797c31 100644 --- a/cosyvoice/hifigan/f0_predictor.py +++ b/cosyvoice/hifigan/f0_predictor.py @@ -13,7 +13,10 @@ # limitations under the License. import torch import torch.nn as nn -from torch.nn.utils.parametrizations import weight_norm +try: + from torch.nn.utils.parametrizations import weight_norm +except ImportError: + from torch.nn.utils import weight_norm class ConvRNNF0Predictor(nn.Module): diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py index c47bf05..50d7f99 100644 --- a/cosyvoice/hifigan/generator.py +++ b/cosyvoice/hifigan/generator.py @@ -23,7 +23,10 @@ import torch.nn.functional as F from torch.nn import Conv1d from torch.nn import ConvTranspose1d from torch.nn.utils import remove_weight_norm -from torch.nn.utils.parametrizations import weight_norm +try: + from torch.nn.utils.parametrizations import weight_norm +except ImportError: + from torch.nn.utils import weight_norm from torch.distributions.uniform import Uniform from cosyvoice.transformer.activation import Snake diff --git a/cosyvoice/hifigan/hifigan.py b/cosyvoice/hifigan/hifigan.py index de623cc..046c2cf 100644 --- a/cosyvoice/hifigan/hifigan.py +++ b/cosyvoice/hifigan/hifigan.py @@ -41,7 +41,7 @@ class HiFiGan(nn.Module): loss_fm = feature_loss(fmap_rs, fmap_gs) loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform) if self.tpr_loss_weight != 0: - loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau) + loss_tpr = tpr_loss(y_d_gs, y_d_rs, self.tpr_loss_tau) else: loss_tpr = torch.zeros(1).to(device) loss_f0 = F.l1_loss(generated_f0, pitch_feat) @@ -56,7 +56,7 @@ class HiFiGan(nn.Module): with torch.no_grad(): generated_speech, generated_f0 = self.generator(batch, device) # 2. calculate discriminator outputs - y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech) + y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach()) # 3. calculate discriminator losses, tpr losses [Optional] loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs) if self.tpr_loss_weight != 0: diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py index bbd3305..670ae69 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import random from typing import Dict, Optional, Callable, List, Generator import torch from torch import nn @@ -21,6 +22,7 @@ from cosyvoice.utils.common import IGNORE_ID from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss from cosyvoice.utils.common import th_accuracy from cosyvoice.utils.file_utils import logging +from cosyvoice.utils.mask import make_pad_mask class TransformerLM(torch.nn.Module): @@ -169,9 +171,6 @@ class TransformerLM(torch.nn.Module): max_token_text_ratio: float = 20, min_token_text_ratio: float = 2, ) -> Generator[torch.Tensor, None, None]: - if self.fp16 is True: - embedding = embedding.half() - device = text.device text = torch.concat([prompt_text, text], dim=1) text_len += prompt_text_len @@ -229,6 +228,17 @@ class Qwen2Encoder(torch.nn.Module): super().__init__() self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path) + def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor): + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T) + outs = self.model( + inputs_embeds=xs, + attention_mask=masks, + output_hidden_states=True, + return_dict=True, + ) + return outs.hidden_states[-1], masks.unsqueeze(1) + def forward_one_step(self, xs, masks, cache=None): input_masks = masks[:, -1, :] outs = self.model( @@ -283,6 +293,82 @@ class Qwen2LM(TransformerLM): self.sampling = sampling self.mix_ratio = mix_ratio + def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len): + lm_target, lm_input = [], [] + text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True) + speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True) + text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True) + speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True) + for i in range(len(text_token)): + # bistream sequence + if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]: + this_lm_target, this_lm_input = [], [] + this_lm_target.append(IGNORE_ID) + this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1)) + for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()): + this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist() + this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist() + if len(this_text_token) == self.mix_ratio[0]: + assert len(this_speech_token) == self.mix_ratio[1] + this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1) + this_lm_target += this_speech_token + this_lm_target.append(self.speech_token_size + 2) + this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]]) + this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]]) + else: + this_lm_target += [-1] * len(this_text_token) + this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist() + this_lm_target.append(self.speech_token_size) + this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:]) + this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1)) + this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:]) + this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0) + # unistream sequence + else: + this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size]) + this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i], + self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0) + lm_target.append(this_lm_target) + lm_input.append(this_lm_input) + lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32) + lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID) + lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID) + return lm_target, lm_input, lm_input_len + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + """ + Args: + text: (B, L, D) + text_lengths: (B,) + audio: (B, T, N) or (B, T) + audio_lengths: (B,) + """ + text_token = batch['text_token'].to(device) + text_token_len = batch['text_token_len'].to(device) + speech_token = batch['speech_token'].to(device) + speech_token_len = batch['speech_token_len'].to(device) + + # 1. encode text_token + text_token_emb = self.llm.model.model.embed_tokens(text_token) + + # 2. encode speech_token + speech_token_emb = self.speech_embedding(speech_token) + + # 3. prepare llm_input/target + lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len) + lm_target = lm_target.to(device) + + # 4. run lm forward + lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device)) + logits = self.llm_decoder(lm_output) + loss = self.criterion_ce(logits, lm_target.to(device)) + acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID) + return {'loss': loss, 'acc': acc} + @torch.inference_mode() def inference( self, @@ -393,8 +479,8 @@ class Qwen2LM(TransformerLM): while True: seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) y_pred, cache = self.llm.forward_one_step(lm_input, - masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), - cache=cache) + masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), + cache=cache) logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) if next_fill_index != -1 and len(out_tokens) == next_fill_index: top_ids = self.speech_token_size + 2 diff --git a/cosyvoice/transformer/embedding.py b/cosyvoice/transformer/embedding.py index eae8c8e..ba20d71 100644 --- a/cosyvoice/transformer/embedding.py +++ b/cosyvoice/transformer/embedding.py @@ -287,8 +287,16 @@ class EspnetRelPositionalEncoding(torch.nn.Module): Returns: torch.Tensor: Corresponding encoding """ - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size, - ] + # How to subscript a Union type: + # https://github.com/pytorch/pytorch/issues/69434 + if isinstance(offset, int): + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset, + ] + elif isinstance(offset, torch.Tensor): + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset, + ] return pos_emb diff --git a/cosyvoice/transformer/upsample_encoder.py b/cosyvoice/transformer/upsample_encoder.py index f67fb98..0d98406 100644 --- a/cosyvoice/transformer/upsample_encoder.py +++ b/cosyvoice/transformer/upsample_encoder.py @@ -56,11 +56,16 @@ class Upsample1D(nn.Module): # In this mode, first repeat interpolate, than conv with stride=1 self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0) - def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor): + def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest") - outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0) + if conv_cache.size(2) == 0: + outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0) + else: + assert conv_cache.size(2) == self.stride * 2 + outputs = torch.concat([conv_cache, outputs], dim=2) + conv_cache_new = outputs[:, :, -self.stride * 2:] outputs = self.conv(outputs) - return outputs, input_lengths * self.stride + return outputs, input_lengths * self.stride, conv_cache_new class PreLookaheadLayer(nn.Module): @@ -78,22 +83,32 @@ class PreLookaheadLayer(nn.Module): kernel_size=3, stride=1, padding=0, ) - def forward(self, inputs: torch.Tensor) -> torch.Tensor: + def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0), conv2_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]: """ inputs: (batch_size, seq_len, channels) """ outputs = inputs.transpose(1, 2).contiguous() + context = context.transpose(1, 2).contiguous() # look ahead - outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0) + if context.size(2) == 0: + outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0) + else: + assert context.size(2) == self.pre_lookahead_len + outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0) outputs = F.leaky_relu(self.conv1(outputs)) # outputs - outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0) + if conv2_cache.size(2) == 0: + outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0) + else: + assert conv2_cache.size(2) == self.conv2.kernel_size[0] - 1 + outputs = torch.concat([conv2_cache, outputs], dim=2) + conv2_cache_new = outputs[:, :, -(self.conv2.kernel_size[0] - 1):] outputs = self.conv2(outputs) outputs = outputs.transpose(1, 2).contiguous() # residual connection outputs = outputs + inputs - return outputs + return outputs, conv2_cache_new class UpsampleConformerEncoder(torch.nn.Module): @@ -240,6 +255,7 @@ class UpsampleConformerEncoder(torch.nn.Module): xs_lens: torch.Tensor, decoding_chunk_size: int = 0, num_decoding_left_chunks: int = -1, + streaming: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """Embed positions in tensor. @@ -270,30 +286,20 @@ class UpsampleConformerEncoder(torch.nn.Module): xs = self.global_cmvn(xs) xs, pos_emb, masks = self.embed(xs, masks) mask_pad = masks # (B, 1, T/subsample_rate) - chunk_masks = add_optional_chunk_mask(xs, masks, - self.use_dynamic_chunk, - self.use_dynamic_left_chunk, - decoding_chunk_size, - self.static_chunk_size, - num_decoding_left_chunks) + chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1) # lookahead + conformer encoder - xs = self.pre_lookahead_layer(xs) + xs, _ = self.pre_lookahead_layer(xs) xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) # upsample + conformer encoder xs = xs.transpose(1, 2).contiguous() - xs, xs_lens = self.up_layer(xs, xs_lens) + xs, xs_lens, _ = self.up_layer(xs, xs_lens) xs = xs.transpose(1, 2).contiguous() T = xs.size(1) masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) xs, pos_emb, masks = self.up_embed(xs, masks) mask_pad = masks # (B, 1, T/subsample_rate) - chunk_masks = add_optional_chunk_mask(xs, masks, - self.use_dynamic_chunk, - self.use_dynamic_left_chunk, - decoding_chunk_size, - self.static_chunk_size * self.up_layer.stride, - num_decoding_left_chunks) + chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1) xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad) if self.normalize_before: @@ -316,3 +322,100 @@ class UpsampleConformerEncoder(torch.nn.Module): for layer in self.up_encoders: xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) return xs + + @torch.jit.export + def forward_chunk( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + offset: int = 0, + context: torch.Tensor = torch.zeros(0, 0, 0), + pre_lookahead_layer_conv2_cache: torch.Tensor = torch.zeros(0, 0, 0), + encoders_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0), + upsample_offset: int = 0, + upsample_conv_cache: torch.Tensor = torch.zeros(0, 0, 0), + upsample_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0) + ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[int, torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]]: + """Embed positions in tensor. + + Args: + xs: padded input tensor (B, T, D) + xs_lens: input length (B) + decoding_chunk_size: decoding chunk size for dynamic chunk + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + Returns: + encoder output tensor xs, and subsampled masks + xs: padded output tensor (B, T' ~= T/subsample_rate, D) + masks: torch.Tensor batch padding mask after subsample + (B, 1, T' ~= T/subsample_rate) + NOTE(xcsong): + We pass the `__call__` method of the modules instead of `forward` to the + checkpointing API because `__call__` attaches all the hooks of the module. + https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + """ + assert xs.size(0) == 1 + # tmp_masks is just for interface compatibility + tmp_masks = torch.ones(1, + xs.size(1), + device=xs.device, + dtype=torch.bool) + tmp_masks = tmp_masks.unsqueeze(1) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) + xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) + offset += xs.size(1) + tmp_masks = torch.ones(1, + context.size(1), + device=context.device, + dtype=torch.bool) + tmp_masks = tmp_masks.unsqueeze(1) + if context.size(1) != 0: + context, _, _ = self.embed(context, tmp_masks, offset) + + # lookahead + conformer encoder + xs, pre_lookahead_layer_conv2_cache = self.pre_lookahead_layer(xs, context, pre_lookahead_layer_conv2_cache) + # NOTE in cache mode we do not need to call add_optional_chunk_mask + chunk_masks = torch.ones((1, xs.size(1), offset), dtype=torch.bool, device=xs.device) + mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device) + encoders_kv_cache_list = [] + for index, layer in enumerate(self.encoders): + xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index]) + encoders_kv_cache_list.append(encoders_kv_cache_new) + encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0) + + # upsample + xs = xs.transpose(1, 2).contiguous() + xs, xs_lens, upsample_conv_cache = self.up_layer(xs, xs_lens, upsample_conv_cache) + xs = xs.transpose(1, 2).contiguous() + + # tmp_masks is just for interface compatibility + tmp_masks = torch.ones(1, + xs.size(1), + device=xs.device, + dtype=torch.bool) + tmp_masks = tmp_masks.unsqueeze(1) + xs, pos_emb, masks = self.up_embed(xs, tmp_masks, upsample_offset) + upsample_offset += xs.size(1) + + # conformer encoder + chunk_masks = torch.ones((1, xs.size(1), upsample_offset), dtype=torch.bool, device=xs.device) + mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device) + upsample_kv_cache_list = [] + for index, layer in enumerate(self.up_encoders): + xs, chunk_masks, upsample_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, upsample_kv_cache[index]) + upsample_kv_cache_list.append(upsample_kv_cache_new) + upsample_kv_cache = torch.stack(upsample_kv_cache_list, dim=0) + + if self.normalize_before: + xs = self.after_norm(xs) + # Here we assume the mask is not changed in encoder layers, so just + # return the masks before encoder layers, and the masks will be used + # for cross attention with decoder later + return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache) diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py index ac7fe93..f0a450c 100644 --- a/cosyvoice/utils/file_utils.py +++ b/cosyvoice/utils/file_utils.py @@ -47,13 +47,8 @@ def load_wav(wav, target_sr): return speech -def convert_onnx_to_trt(trt_model, onnx_model, fp16): +def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): import tensorrt as trt - _min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)] - _opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)] - _max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)] - input_names = ["x", "mask", "mu", "t", "spks", "cond"] - logging.info("Converting onnx to trt...") network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) logger = trt.Logger(trt.Logger.INFO) @@ -72,8 +67,8 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16): print(parser.get_error(error)) raise ValueError('failed to parse {}'.format(onnx_model)) # set input shapes - for i in range(len(input_names)): - profile.set_shape(input_names[i], _min_shape[i], _opt_shape[i], _max_shape[i]) + for i in range(len(trt_kwargs['input_names'])): + profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i]) tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT # set input and output data type for i in range(network.num_inputs): @@ -87,3 +82,4 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16): # save trt engine with open(trt_model, "wb") as f: f.write(engine_bytes) + logging.info("Succesfully convert onnx to trt...") diff --git a/cosyvoice/utils/mask.py b/cosyvoice/utils/mask.py index c164db1..35dcd69 100644 --- a/cosyvoice/utils/mask.py +++ b/cosyvoice/utils/mask.py @@ -15,7 +15,6 @@ # limitations under the License. import torch -from cosyvoice.utils.file_utils import logging ''' def subsequent_mask( size: int, @@ -87,7 +86,7 @@ def subsequent_mask( return mask -def subsequent_chunk_mask_deprecated( +def subsequent_chunk_mask( size: int, chunk_size: int, num_left_chunks: int = -1, @@ -125,41 +124,6 @@ def subsequent_chunk_mask_deprecated( return ret -def subsequent_chunk_mask( - size: int, - chunk_size: int, - num_left_chunks: int = -1, - device: torch.device = torch.device("cpu"), -) -> torch.Tensor: - """Create mask for subsequent steps (size, size) with chunk size, - this is for streaming encoder - - Args: - size (int): size of mask - chunk_size (int): size of chunk - num_left_chunks (int): number of left chunks - <0: use full chunk - >=0: use num_left_chunks - device (torch.device): "cpu" or "cuda" or torch.Tensor.device - - Returns: - torch.Tensor: mask - - Examples: - >>> subsequent_chunk_mask(4, 2) - [[1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 1, 1], - [1, 1, 1, 1]] - """ - # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks - # actually this is not needed after we have inference cache implemented, will remove it later - pos_idx = torch.arange(size, device=device) - block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size - ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1) - return ret - - def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor, use_dynamic_chunk: bool, @@ -233,8 +197,8 @@ def add_optional_chunk_mask(xs: torch.Tensor, chunk_masks = masks assert chunk_masks.dtype == torch.bool if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0: - logging.warning('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!') - chunk_masks[chunk_masks.sum(dim=-1)==0] = True + print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!') + chunk_masks[chunk_masks.sum(dim=-1) == 0] = True return chunk_masks diff --git a/cosyvoice/utils/train_utils.py b/cosyvoice/utils/train_utils.py index 72e291a..a6a1458 100644 --- a/cosyvoice/utils/train_utils.py +++ b/cosyvoice/utils/train_utils.py @@ -286,11 +286,15 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict): # optimizer.step(). if torch.isfinite(grad_norm): scaler.step(optimizer) + else: + logging.warning('get infinite grad_norm, check your code/data if it appears frequently') scaler.update() else: grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip']) if torch.isfinite(grad_norm): optimizer.step() + else: + logging.warning('get infinite grad_norm, check your code/data if it appears frequently') optimizer.zero_grad() scheduler.step() info_dict["lr"] = optimizer.param_groups[0]['lr'] @@ -336,7 +340,7 @@ def log_per_save(writer, info_dict): rank = int(os.environ.get('RANK', 0)) logging.info( 'Epoch {} Step {} CV info lr {} {} rank {}'.format( - epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()]))) + epoch, step + 1, lr, rank, ' '.join(['{} {}'.format(k, v) for k, v in loss_dict.items()]))) if writer is not None: for k in ['epoch', 'lr']: diff --git a/examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml b/examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml index 435355f..4feb14c 100644 --- a/examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml +++ b/examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml @@ -147,7 +147,7 @@ hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan generator: !ref discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator - mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator + mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator mel_spec_transform: [ !ref ] diff --git a/examples/libritts/cosyvoice/conf/cosyvoice.yaml b/examples/libritts/cosyvoice/conf/cosyvoice.yaml index 9286f79..c421b4f 100644 --- a/examples/libritts/cosyvoice/conf/cosyvoice.yaml +++ b/examples/libritts/cosyvoice/conf/cosyvoice.yaml @@ -147,7 +147,7 @@ hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan generator: !ref discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator - mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator + mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator mel_spec_transform: [ !ref ] diff --git a/examples/libritts/cosyvoice2/conf/cosyvoice2.yaml b/examples/libritts/cosyvoice2/conf/cosyvoice2.yaml new file mode 100644 index 0000000..d6bdeb6 --- /dev/null +++ b/examples/libritts/cosyvoice2/conf/cosyvoice2.yaml @@ -0,0 +1,233 @@ +# set random seed, so that you may reproduce your result. +__set_seed1: !apply:random.seed [1986] +__set_seed2: !apply:numpy.random.seed [1986] +__set_seed3: !apply:torch.manual_seed [1986] +__set_seed4: !apply:torch.cuda.manual_seed_all [1986] + +# fixed params +sample_rate: 24000 +llm_input_size: 896 +llm_output_size: 896 +spk_embed_dim: 192 +qwen_pretrain_path: '' +token_frame_rate: 25 +token_mel_ratio: 2 + +# stream related params +chunk_size: 25 # streaming inference chunk size, in token +num_decoding_left_chunks: 1 # streaming inference flow decoder left chunk size, <0 means use all left chunks + +# model params +# for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. +# for system/third_party class/function, we do not require this. +llm: !new:cosyvoice.llm.llm.Qwen2LM + llm_input_size: !ref + llm_output_size: !ref + speech_token_size: 6561 + length_normalized_loss: True + lsm_weight: 0 + mix_ratio: [5, 15] + llm: !new:cosyvoice.llm.llm.Qwen2Encoder + pretrain_path: !ref + sampling: !name:cosyvoice.utils.common.ras_sampling + top_p: 0.8 + top_k: 25 + win_size: 10 + tau_r: 0.1 + +flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec + input_size: 512 + output_size: 80 + spk_embed_dim: !ref + output_type: 'mel' + vocab_size: 6561 + input_frame_rate: !ref + only_mask_loss: True + token_mel_ratio: !ref + pre_lookahead_len: 3 + encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder + output_size: 512 + attention_heads: 8 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + normalize_before: True + input_layer: 'linear' + pos_enc_layer_type: 'rel_pos_espnet' + selfattention_layer_type: 'rel_selfattn' + input_size: 512 + use_cnn_module: False + macaron_style: False + static_chunk_size: !ref + decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM + in_channels: 240 + n_spks: 1 + spk_emb_dim: 80 + cfm_params: !new:omegaconf.DictConfig + content: + sigma_min: 1e-06 + solver: 'euler' + t_scheduler: 'cosine' + training_cfg_rate: 0.2 + inference_cfg_rate: 0.7 + reg_loss_type: 'l1' + estimator: !new:cosyvoice.flow.decoder.CausalConditionalDecoder + in_channels: 320 + out_channels: 80 + channels: [256] + dropout: 0.0 + attention_head_dim: 64 + n_blocks: 4 + num_mid_blocks: 12 + num_heads: 8 + act_fn: 'gelu' + static_chunk_size: !ref * + num_decoding_left_chunks: !ref + +hift: !new:cosyvoice.hifigan.generator.HiFTGenerator + in_channels: 80 + base_channels: 512 + nb_harmonics: 8 + sampling_rate: !ref + nsf_alpha: 0.1 + nsf_sigma: 0.003 + nsf_voiced_threshold: 10 + upsample_rates: [8, 5, 3] + upsample_kernel_sizes: [16, 11, 7] + istft_params: + n_fft: 16 + hop_len: 4 + resblock_kernel_sizes: [3, 7, 11] + resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + source_resblock_kernel_sizes: [7, 7, 11] + source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + lrelu_slope: 0.1 + audio_limit: 0.99 + f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor + num_class: 1 + in_channels: 80 + cond_channels: 512 + +# gan related module +mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram + n_fft: 1920 + num_mels: 80 + sampling_rate: !ref + hop_size: 480 + win_size: 1920 + fmin: 0 + fmax: null + center: False +hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan + generator: !ref + discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator + mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator + mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator + mel_spec_transform: [ + !ref + ] + +# processor functions +parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener +get_tokenizer: !name:cosyvoice.tokenizer.tokenizer.get_qwen_tokenizer + token_path: !ref + skip_special_tokens: True +allowed_special: 'all' +tokenize: !name:cosyvoice.dataset.processor.tokenize + get_tokenizer: !ref + allowed_special: !ref +filter: !name:cosyvoice.dataset.processor.filter + max_length: 40960 + min_length: 100 + token_max_length: 200 + token_min_length: 1 +resample: !name:cosyvoice.dataset.processor.resample + resample_rate: !ref +truncate: !name:cosyvoice.dataset.processor.truncate + truncate_length: 24480 # must be a multiplier of hop_size +feat_extractor: !name:matcha.utils.audio.mel_spectrogram + n_fft: 1920 + num_mels: 80 + sampling_rate: !ref + hop_size: 480 + win_size: 1920 + fmin: 0 + fmax: 8000 + center: False +compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank + feat_extractor: !ref +compute_f0: !name:cosyvoice.dataset.processor.compute_f0 + sample_rate: !ref + hop_size: 480 +parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding + normalize: True +shuffle: !name:cosyvoice.dataset.processor.shuffle + shuffle_size: 1000 +sort: !name:cosyvoice.dataset.processor.sort + sort_size: 500 # sort_size should be less than shuffle_size +batch: !name:cosyvoice.dataset.processor.batch + batch_type: 'dynamic' + max_frames_in_batch: 2000 +padding: !name:cosyvoice.dataset.processor.padding + use_spk_embedding: False # change to True during sft + + +# dataset processor pipeline +data_pipeline: [ + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , +] +data_pipeline_gan: [ + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , +] + +# llm flow train conf +train_conf: + optim: adam + optim_conf: + lr: 1e-5 # change to 1e-5 during sft + scheduler: constantlr # change to constantlr during sft + scheduler_conf: + warmup_steps: 2500 + max_epoch: 200 + grad_clip: 5 + accum_grad: 2 + log_interval: 100 + save_per_step: -1 + +# gan train conf +train_conf_gan: + optim: adam + optim_conf: + lr: 0.0002 # use small lr for gan training + scheduler: constantlr + optim_d: adam + optim_conf_d: + lr: 0.0002 # use small lr for gan training + scheduler_d: constantlr + max_epoch: 200 + grad_clip: 5 + accum_grad: 1 # in gan training, accum_grad must be 1 + log_interval: 100 + save_per_step: -1 \ No newline at end of file diff --git a/examples/libritts/cosyvoice2/conf/ds_stage2.json b/examples/libritts/cosyvoice2/conf/ds_stage2.json new file mode 100644 index 0000000..2b2de3d --- /dev/null +++ b/examples/libritts/cosyvoice2/conf/ds_stage2.json @@ -0,0 +1,42 @@ +{ + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 100, + "gradient_clipping": 5, + "fp16": { + "enabled": false, + "auto_cast": false, + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 256, + "hysteresis": 2, + "consecutive_hysteresis": false, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": false + }, + "zero_force_ds_cpu_optimizer": false, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "none", + "pin_memory": true + }, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "overlap_comm": false, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "contiguous_gradients" : true + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 0.001, + "weight_decay": 0.0001, + "torch_adam": true, + "adam_w_mode": true + } + } +} \ No newline at end of file diff --git a/examples/libritts/cosyvoice2/local b/examples/libritts/cosyvoice2/local new file mode 120000 index 0000000..5e847a1 --- /dev/null +++ b/examples/libritts/cosyvoice2/local @@ -0,0 +1 @@ +../cosyvoice/local \ No newline at end of file diff --git a/examples/libritts/cosyvoice2/path.sh b/examples/libritts/cosyvoice2/path.sh new file mode 100644 index 0000000..e0fa06c --- /dev/null +++ b/examples/libritts/cosyvoice2/path.sh @@ -0,0 +1,3 @@ +# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=../../../:../../../third_party/Matcha-TTS:$PYTHONPATH diff --git a/examples/libritts/cosyvoice2/run.sh b/examples/libritts/cosyvoice2/run.sh new file mode 100644 index 0000000..e681497 --- /dev/null +++ b/examples/libritts/cosyvoice2/run.sh @@ -0,0 +1,130 @@ +#!/bin/bash +# Copyright 2024 Alibaba Inc. All Rights Reserved. +. ./path.sh || exit 1; + +stage=-1 +stop_stage=3 + +data_url=www.openslr.org/resources/60 +data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts +pretrained_model_dir=../../../pretrained_models/CosyVoice2-0.5B + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + echo "Data Download" + for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do + local/download_and_untar.sh ${data_dir} ${data_url} ${part} + done +fi + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt" + for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do + mkdir -p data/$x + python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x + done +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir" + for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do + tools/extract_embedding.py --dir data/$x \ + --onnx_path $pretrained_model_dir/campplus.onnx + done +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir" + for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do + tools/extract_speech_token.py --dir data/$x \ + --onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx + done +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt" + for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do + mkdir -p data/$x/parquet + tools/make_parquet_list.py --num_utts_per_parquet 1000 \ + --num_processes 10 \ + --src_dir data/$x \ + --des_dir data/$x/parquet + done +fi + +# inference +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + echo "Run inference. Please make sure utt in tts_text is in prompt_data" + # TODO consider remove bin/inference.py, or use similar initilization method as in readme + for mode in sft zero_shot; do + python cosyvoice/bin/inference.py --mode $mode \ + --gpu 0 \ + --config conf/cosyvoice2.yaml \ + --prompt_data data/test-clean/parquet/data.list \ + --prompt_utt2data data/test-clean/parquet/utt2data.list \ + --tts_text `pwd`/tts_text.json \ + --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \ + --llm_model $pretrained_model_dir/llm.pt \ + --flow_model $pretrained_model_dir/flow.pt \ + --hifigan_model $pretrained_model_dir/hift.pt \ + --result_dir `pwd`/exp/cosyvoice/test-clean/$mode + done +fi + +# train llm +export CUDA_VISIBLE_DEVICES="0,1,2,3" +num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +job_id=1986 +dist_backend="nccl" +num_workers=2 +prefetch=100 +train_engine=torch_ddp +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml" + if [ $train_engine == 'deepspeed' ]; then + echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary" + fi + cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list + cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list + # NOTE will update llm/hift training later + for model in llm flow; do + torchrun --nnodes=1 --nproc_per_node=$num_gpus \ + --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \ + cosyvoice/bin/train.py \ + --train_engine $train_engine \ + --config conf/cosyvoice2.yaml \ + --train_data data/train.data.list \ + --cv_data data/dev.data.list \ + --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \ + --model $model \ + --checkpoint $pretrained_model_dir/$model.pt \ + --model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \ + --tensorboard_dir `pwd`/tensorboard/cosyvoice2/$model/$train_engine \ + --ddp.dist_backend $dist_backend \ + --num_workers ${num_workers} \ + --prefetch ${prefetch} \ + --pin_memory \ + --use_amp \ + --deepspeed_config ./conf/ds_stage2.json \ + --deepspeed.save_states model+optimizer + done +fi + +# average model +average_num=5 +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + for model in llm flow hifigan; do + decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt + echo "do model average and final checkpoint is $decode_checkpoint" + python cosyvoice/bin/average_model.py \ + --dst_model $decode_checkpoint \ + --src_path `pwd`/exp/cosyvoice/$model/$train_engine \ + --num ${average_num} \ + --val_best + done +fi + +if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then + echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir" + python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir + python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir +fi \ No newline at end of file diff --git a/examples/libritts/cosyvoice2/tts_text.json b/examples/libritts/cosyvoice2/tts_text.json new file mode 100644 index 0000000..9f3e8d9 --- /dev/null +++ b/examples/libritts/cosyvoice2/tts_text.json @@ -0,0 +1,5 @@ +{ + "1089_134686_000002_000000": [ + "hello, my name is Jack. What is your name?" + ] +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 1998c59..4166dac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ inflect==7.3.1 librosa==0.10.2 lightning==2.2.4 matplotlib==3.7.5 -modelscope==1.15.0 +modelscope==1.20.0 networkx==3.1 omegaconf==2.3.0 onnx==1.16.0 @@ -21,6 +21,7 @@ onnxruntime-gpu==1.18.0; sys_platform == 'linux' onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'win32' openai-whisper==20231117 protobuf==4.25 +pyarrow==18.1.0 pydantic==2.7.0 pyworld==0.3.4 rich==13.7.1