diff --git a/cosyvoice/bin/export_jit.py b/cosyvoice/bin/export_jit.py index ddd486e..99b203f 100644 --- a/cosyvoice/bin/export_jit.py +++ b/cosyvoice/bin/export_jit.py @@ -24,6 +24,7 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append('{}/../..'.format(ROOT_DIR)) sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 +from cosyvoice.utils.file_utils import logging def get_args(): @@ -71,6 +72,7 @@ def main(): script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir)) script = get_optimized_script(llm_text_encoder.half()) script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir)) + logging.info('successfully export llm_text_encoder') # 2. export llm llm llm_llm = model.model.llm.llm @@ -78,14 +80,23 @@ def main(): script.save('{}/llm.llm.fp32.zip'.format(args.model_dir)) script = get_optimized_script(llm_llm.half(), ['forward_chunk']) script.save('{}/llm.llm.fp16.zip'.format(args.model_dir)) + logging.info('successfully export llm_llm') - # 3. export flow encoder - flow_encoder = model.model.flow.encoder - script = get_optimized_script(flow_encoder) - script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) - script = get_optimized_script(flow_encoder.half()) - script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir)) - + # 3. export flow encoder + flow_encoder = model.model.flow.encoder + script = get_optimized_script(flow_encoder) + script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) + script = get_optimized_script(flow_encoder.half()) + script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir)) + logging.info('successfully export flow_encoder') + else: + # 3. export flow encoder + flow_encoder = model.model.flow.encoder + script = get_optimized_script(flow_encoder, ['forward_chunk']) + script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) + script = get_optimized_script(flow_encoder.half(), ['forward_chunk']) + script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir)) + logging.info('successfully export flow_encoder') if __name__ == '__main__': main() diff --git a/cosyvoice/bin/export_onnx.py b/cosyvoice/bin/export_onnx.py index 9ddd358..7b1d9ec 100644 --- a/cosyvoice/bin/export_onnx.py +++ b/cosyvoice/bin/export_onnx.py @@ -28,6 +28,7 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append('{}/../..'.format(ROOT_DIR)) sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 +from cosyvoice.utils.file_utils import logging def get_dummy_input(batch_size, seq_len, out_channels, device): @@ -51,6 +52,7 @@ def get_args(): return args +@torch.no_grad() def main(): args = get_args() logging.basicConfig(level=logging.DEBUG, @@ -64,52 +66,125 @@ def main(): except Exception: raise TypeError('no valid model_type!') - # 1. export flow decoder estimator - estimator = model.model.flow.decoder.estimator + if not isinstance(model, CosyVoice2): + # 1. 
export flow decoder estimator + estimator = model.model.flow.decoder.estimator + estimator.eval() - device = model.model.device - batch_size, seq_len = 2, 256 - out_channels = model.model.flow.decoder.estimator.out_channels - x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) - torch.onnx.export( - estimator, - (x, mask, mu, t, spks, cond), - '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), - export_params=True, - opset_version=18, - do_constant_folding=True, - input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], - output_names=['estimator_out'], - dynamic_axes={ - 'x': {2: 'seq_len'}, - 'mask': {2: 'seq_len'}, - 'mu': {2: 'seq_len'}, - 'cond': {2: 'seq_len'}, - 'estimator_out': {2: 'seq_len'}, - } - ) + device = model.model.device + batch_size, seq_len = 2, 256 + out_channels = model.model.flow.decoder.estimator.out_channels + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) + torch.onnx.export( + estimator, + (x, mask, mu, t, spks, cond), + '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], + output_names=['estimator_out'], + dynamic_axes={ + 'x': {2: 'seq_len'}, + 'mask': {2: 'seq_len'}, + 'mu': {2: 'seq_len'}, + 'cond': {2: 'seq_len'}, + 'estimator_out': {2: 'seq_len'}, + } + ) - # 2. test computation consistency - option = onnxruntime.SessionOptions() - option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - option.intra_op_num_threads = 1 - providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] - estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), - sess_options=option, providers=providers) + # 2. test computation consistency + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + sess_options=option, providers=providers) - for _ in tqdm(range(10)): - x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) - output_pytorch = estimator(x, mask, mu, t, spks, cond) - ort_inputs = { - 'x': x.cpu().numpy(), - 'mask': mask.cpu().numpy(), - 'mu': mu.cpu().numpy(), - 't': t.cpu().numpy(), - 'spks': spks.cpu().numpy(), - 'cond': cond.cpu().numpy() - } - output_onnx = estimator_onnx.run(None, ort_inputs)[0] - torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) + for _ in tqdm(range(10)): + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) + output_pytorch = estimator(x, mask, mu, t, spks, cond) + ort_inputs = { + 'x': x.cpu().numpy(), + 'mask': mask.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy() + } + output_onnx = estimator_onnx.run(None, ort_inputs)[0] + torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) + logging.info('successfully export estimator') + else: + # 1. 
export flow decoder estimator + estimator = model.model.flow.decoder.estimator + estimator.forward = estimator.forward_chunk + estimator.eval() + + device = model.model.device + batch_size, seq_len = 2, 256 + out_channels = model.model.flow.decoder.estimator.out_channels + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) + cache = model.model.init_flow_cache()['decoder_cache'] + cache.pop('offset') + cache = {k: v[0] for k, v in cache.items()} + torch.onnx.export( + estimator, + (x, mask, mu, t, spks, cond, + cache['down_blocks_conv_cache'], + cache['down_blocks_kv_cache'], + cache['mid_blocks_conv_cache'], + cache['mid_blocks_kv_cache'], + cache['up_blocks_conv_cache'], + cache['up_blocks_kv_cache'], + cache['final_blocks_conv_cache']), + '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache', 'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'], + output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out', 'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'], + dynamic_axes={ + 'x': {2: 'seq_len'}, + 'mask': {2: 'seq_len'}, + 'mu': {2: 'seq_len'}, + 'cond': {2: 'seq_len'}, + 'down_blocks_kv_cache': {3: 'seq_len'}, + 'mid_blocks_kv_cache': {3: 'seq_len'}, + 'up_blocks_kv_cache': {3: 'seq_len'}, + 'estimator_out': {2: 'seq_len'}, + 'down_blocks_kv_cache_out': {3: 'seq_len'}, + 'mid_blocks_kv_cache_out': {3: 'seq_len'}, + 'up_blocks_kv_cache_out': {3: 'seq_len'}, + } + ) + + # 2. 
test computation consistency + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + sess_options=option, providers=providers) + + for _ in tqdm(range(10)): + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) + cache = model.model.init_flow_cache()['decoder_cache'] + cache.pop('offset') + cache = {k: v[0] for k, v in cache.items()} + output_pytorch = estimator(x, mask, mu, t, spks, cond, **{k: v.clone() for k, v in cache.items()}) + ort_inputs = { + 'x': x.cpu().numpy(), + 'mask': mask.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy(), + } + output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}}) + for i, j in zip(output_pytorch, output_onnx): + torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4) + logging.info('successfully export estimator') if __name__ == "__main__": diff --git a/cosyvoice/bin/export_trt.sh b/cosyvoice/bin/export_trt.sh index 808d02a..3a1a35e 100644 --- a/cosyvoice/bin/export_trt.sh +++ b/cosyvoice/bin/export_trt.sh @@ -3,8 +3,23 @@ # download tensorrt from https://developer.nvidia.com/tensorrt/download/10x, check your system and cuda for compatibability # for example for linux + cuda12.4, you can download https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz TRT_DIR= -MODEL_DIR= - +MODEL_DIR= export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64 + +# cosyvoice export $TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp32.mygpu.plan --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw --outputIOFormats=fp32:chw $TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw + +# cosyvoice2 export with cache +$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp32.mygpu.plan \ + --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4,down_blocks_kv_cache:1x4x2x0x512x2,mid_blocks_kv_cache:12x4x2x0x512x2,up_blocks_kv_cache:1x4x2x0x512x2 \ + --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193,down_blocks_kv_cache:1x4x2x193x512x2,mid_blocks_kv_cache:12x4x2x193x512x2,up_blocks_kv_cache:1x4x2x193x512x2 \ + --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800,down_blocks_kv_cache:1x4x2x200x512x2,mid_blocks_kv_cache:12x4x2x200x512x2,up_blocks_kv_cache:1x4x2x200x512x2 \ + 
--inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw \ + --outputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw +$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 \ + --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4,down_blocks_kv_cache:1x4x2x0x512x2,mid_blocks_kv_cache:12x4x2x0x512x2,up_blocks_kv_cache:1x4x2x0x512x2 \ + --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193,down_blocks_kv_cache:1x4x2x193x512x2,mid_blocks_kv_cache:12x4x2x193x512x2,up_blocks_kv_cache:1x4x2x193x512x2 \ + --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800,down_blocks_kv_cache:1x4x2x200x512x2,mid_blocks_kv_cache:12x4x2x200x512x2,up_blocks_kv_cache:1x4x2x200x512x2 \ + --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw \ + --outputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw diff --git a/cosyvoice/bin/inference.py b/cosyvoice/bin/inference.py index 2cb831a..dd3848a 100644 --- a/cosyvoice/bin/inference.py +++ b/cosyvoice/bin/inference.py @@ -23,7 +23,7 @@ from torch.utils.data import DataLoader import torchaudio from hyperpyyaml import load_hyperpyyaml from tqdm import tqdm -from cosyvoice.cli.model import CosyVoiceModel +from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model from cosyvoice.dataset.dataset import Dataset @@ -33,6 +33,7 @@ def get_args(): parser.add_argument('--prompt_data', required=True, help='prompt data file') parser.add_argument('--prompt_utt2data', required=True, help='prompt data file') parser.add_argument('--tts_text', required=True, help='tts input file') + parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path') parser.add_argument('--llm_model', required=True, help='llm model file') parser.add_argument('--flow_model', required=True, help='flow model file') parser.add_argument('--hifigan_model', required=True, help='hifigan model file') @@ -59,10 +60,18 @@ def main(): # Init cosyvoice models from configs use_cuda = args.gpu >= 0 and torch.cuda.is_available() device = torch.device('cuda' if use_cuda else 'cpu') - with open(args.config, 'r') as f: - configs = load_hyperpyyaml(f) + try: + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path}) + model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16=False) + except Exception: + try: + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f) + model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16=False) + except Exception: + raise TypeError('no valid model_type!') - model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift']) model.load(args.llm_model, args.flow_model, args.hifigan_model) test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, @@ -104,7 +113,7 @@ def main(): tts_speeches = torch.concat(tts_speeches, dim=1) tts_key = '{}_{}'.format(utts[0], tts_index[0]) tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key)) - torchaudio.save(tts_fn, tts_speeches, sample_rate=22050) + torchaudio.save(tts_fn, tts_speeches, sample_rate=configs['sample_rate'], backend='soundfile') f.write('{} {}\n'.format(tts_key, tts_fn)) f.flush() 
f.close() diff --git a/cosyvoice/bin/train.py b/cosyvoice/bin/train.py index 3b4710e..b214e6a 100644 --- a/cosyvoice/bin/train.py +++ b/cosyvoice/bin/train.py @@ -46,6 +46,7 @@ def get_args(): parser.add_argument('--config', required=True, help='config file') parser.add_argument('--train_data', required=True, help='train data file') parser.add_argument('--cv_data', required=True, help='cv data file') + parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path') parser.add_argument('--checkpoint', help='checkpoint model') parser.add_argument('--model_dir', required=True, help='save model dir') parser.add_argument('--tensorboard_dir', @@ -97,8 +98,12 @@ def main(): override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model} if gan is True: override_dict.pop('hift') - with open(args.config, 'r') as f: - configs = load_hyperpyyaml(f, overrides=override_dict) + try: + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path}) + except Exception: + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f, overrides=override_dict) if gan is True: configs['train_conf'] = configs['train_conf_gan'] configs['train_conf'].update(vars(args)) diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index e2d62e2..a511d78 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -32,7 +32,10 @@ class CosyVoice: self.fp16 = fp16 if not os.path.exists(model_dir): model_dir = snapshot_download(model_dir) - with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f: + hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir) + if not os.path.exists(hyper_yaml_path): + raise ValueError('{} not found!'.format(hyper_yaml_path)) + with open(hyper_yaml_path, 'r') as f: configs = load_hyperpyyaml(f) assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir) self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'], @@ -132,7 +135,10 @@ class CosyVoice2(CosyVoice): self.fp16 = fp16 if not os.path.exists(model_dir): model_dir = snapshot_download(model_dir) - with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f: + hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir) + if not os.path.exists(hyper_yaml_path): + raise ValueError('{} not found!'.format(hyper_yaml_path)) + with open(hyper_yaml_path, 'r') as f: configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')}) assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir) self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'], diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 9ebf8cb..5d29827 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -44,8 +44,6 @@ class CosyVoiceModel: self.token_min_hop_len = 2 * self.flow.input_frame_rate self.token_max_hop_len = 4 * self.flow.input_frame_rate self.token_overlap_len = 20 - # here we fix set flow.decoder.estimator.static_chunk_size = 0 for compatibability - self.flow.decoder.estimator.static_chunk_size = 0 # mel fade in out self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256) self.mel_window = np.hamming(2 * self.mel_overlap_len) @@ -121,15 +119,14 @@ class CosyVoiceModel: self.llm_end_dict[uuid] = True def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0): - 
tts_mel, flow_cache = self.flow.inference(token=token.to(self.device), - token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), - prompt_token=prompt_token.to(self.device), - prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), - prompt_feat=prompt_feat.to(self.device), - prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), - embedding=embedding.to(self.device), - flow_cache=self.flow_cache_dict[uuid]) - self.flow_cache_dict[uuid] = flow_cache + tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device), + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=embedding.to(self.device), + flow_cache=self.flow_cache_dict[uuid]) # mel overlap fade in out if self.mel_overlap_dict[uuid].shape[2] != 0: @@ -276,6 +273,7 @@ class CosyVoiceModel: self.llm_end_dict.pop(this_uuid) self.mel_overlap_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) + self.flow_cache_dict.pop(this_uuid) torch.cuda.empty_cache() @@ -297,9 +295,8 @@ class CosyVoice2Model(CosyVoiceModel): self.llm.half() self.flow.half() self.token_hop_len = 2 * self.flow.input_frame_rate - # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache - self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate - self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio + # flow decoder required_cache_size + self.flow_decoder_required_cache_size = self.flow.decoder.estimator.num_decoding_left_chunks * self.flow.input_frame_rate * self.flow.token_mel_ratio # hift cache self.mel_cache_len = 8 self.source_cache_len = int(self.mel_cache_len * 480) @@ -312,22 +309,49 @@ class CosyVoice2Model(CosyVoiceModel): # dict used to store session related variable self.tts_speech_token_dict = {} self.llm_end_dict = {} + self.flow_cache_dict = {} self.hift_cache_dict = {} + def init_flow_cache(self): + encoder_cache = {'offset': 0, + 'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device), + 'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device), + 'upsample_offset': 0, + 'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device), + 'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)} + decoder_cache = {'offset': 0, + 'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device), + 'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device), + 'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device), + 'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, 0, 512, 2).to(self.device), + 'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device), + 'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device), + 'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)} + cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache} + return cache + + def trim_flow_cache(self, cache): + if cache['decoder_cache']['down_blocks_kv_cache'].size(4) > self.flow_decoder_required_cache_size: + cache['decoder_cache']['down_blocks_kv_cache'] = cache['decoder_cache']['down_blocks_kv_cache'][:, :, :, :, 
-self.flow_decoder_required_cache_size:] + cache['decoder_cache']['mid_blocks_kv_cache'] = cache['decoder_cache']['mid_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:] + cache['decoder_cache']['up_blocks_kv_cache'] = cache['decoder_cache']['up_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:] + return cache + def load_jit(self, flow_encoder_model): flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder - def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0): - tts_mel, _ = self.flow.inference(token=token.to(self.device), - token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), - prompt_token=prompt_token.to(self.device), - prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), - prompt_feat=prompt_feat.to(self.device), - prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), - embedding=embedding.to(self.device), - finalize=finalize) - tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:] + def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0): + tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device), + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=embedding.to(self.device), + cache=self.flow_cache_dict[uuid], + finalize=finalize) + self.flow_cache_dict[uuid] = self.trim_flow_cache(self.flow_cache_dict[uuid]) # append hift cache if self.hift_cache_dict[uuid] is not None: hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source'] @@ -362,24 +386,27 @@ class CosyVoice2Model(CosyVoiceModel): with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None + self.flow_cache_dict[this_uuid] = self.init_flow_cache() p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) p.start() if stream is True: - token_offset = 0 while True: time.sleep(0.1) - if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len: - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) + if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len: + this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, embedding=flow_embedding, uuid=this_uuid, - token_offset=token_offset, finalize=False) - token_offset += self.token_hop_len + # NOTE in cache inference mode, we only use flow_prompt_speech_token/prompt_speech_feat in first chunk + flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device) + prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device) yield {'tts_speech': 
this_tts_speech.cpu()} - if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len: + with self.lock: + self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:] + if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len: break p.join() # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None @@ -389,7 +416,6 @@ class CosyVoice2Model(CosyVoiceModel): prompt_feat=prompt_speech_feat, embedding=flow_embedding, uuid=this_uuid, - token_offset=token_offset, finalize=True) yield {'tts_speech': this_tts_speech.cpu()} else: @@ -401,11 +427,12 @@ class CosyVoice2Model(CosyVoiceModel): prompt_feat=prompt_speech_feat, embedding=flow_embedding, uuid=this_uuid, - token_offset=0, finalize=True, speed=speed) yield {'tts_speech': this_tts_speech.cpu()} with self.lock: self.tts_speech_token_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid) + self.hift_cache_dict.pop(this_uuid) + self.flow_cache_dict.pop(this_uuid) torch.cuda.empty_cache() diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py index 420a1bf..865dedc 100644 --- a/cosyvoice/flow/decoder.py +++ b/cosyvoice/flow/decoder.py @@ -11,14 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple, Optional, Dict, Any import torch import torch.nn as nn import torch.nn.functional as F from einops import pack, rearrange, repeat +from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate from cosyvoice.utils.common import mask_to_bias from cosyvoice.utils.mask import add_optional_chunk_mask from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D -from matcha.models.components.transformer import BasicTransformerBlock +from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph class Transpose(torch.nn.Module): @@ -27,34 +29,11 @@ class Transpose(torch.nn.Module): self.dim0 = dim0 self.dim1 = dim1 - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]: x = torch.transpose(x, self.dim0, self.dim1) return x -class CausalBlock1D(Block1D): - def __init__(self, dim: int, dim_out: int): - super(CausalBlock1D, self).__init__(dim, dim_out) - self.block = torch.nn.Sequential( - CausalConv1d(dim, dim_out, 3), - Transpose(1, 2), - nn.LayerNorm(dim_out), - Transpose(1, 2), - nn.Mish(), - ) - - def forward(self, x: torch.Tensor, mask: torch.Tensor): - output = self.block(x * mask) - return output * mask - - -class CausalResnetBlock1D(ResnetBlock1D): - def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8): - super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups) - self.block1 = CausalBlock1D(dim, dim_out) - self.block2 = CausalBlock1D(dim_out, dim_out) - - class CausalConv1d(torch.nn.Conv1d): def __init__( self, @@ -76,12 +55,332 @@ class CausalConv1d(torch.nn.Conv1d): padding_mode=padding_mode, device=device, dtype=dtype) assert stride == 1 - self.causal_padding = (kernel_size - 1, 0) + self.causal_padding = kernel_size - 1 - def forward(self, x: torch.Tensor): - x = F.pad(x, 
self.causal_padding) + def forward(self, x: torch.Tensor, cache: torch.Tensor=torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]: + if cache.size(2) == 0: + x = F.pad(x, (self.causal_padding, 0), value=0.0) + else: + assert cache.size(2) == self.causal_padding + x = torch.concat([cache, x], dim=2) + cache = x[:, :, -self.causal_padding:] x = super(CausalConv1d, self).forward(x) - return x + return x, cache + + +class CausalBlock1D(Block1D): + def __init__(self, dim: int, dim_out: int): + super(CausalBlock1D, self).__init__(dim, dim_out) + self.block = torch.nn.Sequential( + CausalConv1d(dim, dim_out, 3), + Transpose(1, 2), + nn.LayerNorm(dim_out), + Transpose(1, 2), + nn.Mish(), + ) + + def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor=torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]: + output, cache = self.block[0](x * mask, cache) + for i in range(1, len(self.block)): + output = self.block[i](output) + return output * mask, cache + + +class CausalResnetBlock1D(ResnetBlock1D): + def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8): + super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups) + self.block1 = CausalBlock1D(dim, dim_out) + self.block2 = CausalBlock1D(dim_out, dim_out) + + def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor, block1_cache: torch.Tensor=torch.zeros(0, 0, 0), block2_cache: torch.Tensor=torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + h, block1_cache = self.block1(x, mask, block1_cache) + h += self.mlp(time_emb).unsqueeze(-1) + h, block2_cache = self.block2(h, mask, block2_cache) + output = h + self.res_conv(x * mask) + return output, block1_cache, block2_cache + + +class CausalAttnProcessor2_0(AttnProcessor2_0): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + super(CausalAttnProcessor2_0, self).__init__() + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + *args, + **kwargs, + ) -> Tuple[torch.FloatTensor, torch.Tensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + # NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key_cache = attn.to_k(encoder_hidden_states) + value_cache = attn.to_v(encoder_hidden_states) + # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2) + if cache.size(0) != 0: + key = torch.concat([cache[:, :, :, 0], key_cache], dim=1) + value = torch.concat([cache[:, :, :, 1], value_cache], dim=1) + else: + key, value = key_cache, value_cache + cache = torch.stack([key_cache, value_cache], dim=3) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states, cache + + +@maybe_allow_in_graph +class CausalAttention(Attention): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor2_0"] = None, + 
out_dim: int = None, + ): + super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups, + added_kv_proj_dim, norm_num_groups, spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection, _from_deprecated_attn_block, processor, out_dim) + processor = CausalAttnProcessor2_0() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + **cross_attention_kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." 
+ ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cache=cache, + **cross_attention_kwargs, + ) + + +@maybe_allow_in_graph +class CausalBasicTransformerBlock(BasicTransformerBlock): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + ): + super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout, cross_attention_dim, activation_fn, num_embeds_ada_norm, + attention_bias, only_cross_attention, double_self_attention, upcast_attention, norm_elementwise_affine, norm_type, final_dropout) + self.attn1 = CausalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + attn_output, cache = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, + cache=cache, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. 
Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states, cache class ConditionalDecoder(nn.Module): @@ -89,7 +388,6 @@ class ConditionalDecoder(nn.Module): self, in_channels, out_channels, - causal=False, channels=(256, 256), dropout=0.05, attention_head_dim=64, @@ -106,7 +404,7 @@ class ConditionalDecoder(nn.Module): channels = tuple(channels) self.in_channels = in_channels self.out_channels = out_channels - self.causal = causal + self.time_embeddings = SinusoidalPosEmb(in_channels) time_embed_dim = channels[0] * 4 self.time_mlp = TimestepEmbedding( @@ -123,8 +421,7 @@ class ConditionalDecoder(nn.Module): input_channel = output_channel output_channel = channels[i] is_last = i == len(channels) - 1 - resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ - ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( @@ -138,16 +435,14 @@ class ConditionalDecoder(nn.Module): ] ) downsample = ( - Downsample1D(output_channel) if not is_last else - CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) + Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1) ) self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) for _ in range(num_mid_blocks): input_channel = channels[-1] out_channels = channels[-1] - resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ - ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) transformer_blocks = nn.ModuleList( [ @@ -169,11 +464,7 @@ class ConditionalDecoder(nn.Module): input_channel = channels[i] * 2 output_channel = channels[i + 1] is_last = i == len(channels) - 2 - resnet = CausalResnetBlock1D( - dim=input_channel, - dim_out=output_channel, - time_emb_dim=time_embed_dim, - ) if self.causal else ResnetBlock1D( + resnet = ResnetBlock1D( dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim, @@ -193,10 +484,10 @@ class ConditionalDecoder(nn.Module): upsample = ( Upsample1D(output_channel, use_conv_transpose=True) if not is_last - else CausalConv1d(output_channel, output_channel, 3) if 
self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) + else nn.Conv1d(output_channel, output_channel, 3, padding=1) ) self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) - self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1]) + self.final_block = Block1D(channels[-1], channels[-1]) self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) self.initialize_weights() @@ -249,9 +540,8 @@ class ConditionalDecoder(nn.Module): mask_down = masks[-1] x = resnet(x, mask_down, t) x = rearrange(x, "b c t -> b t c").contiguous() - # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) - attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1) - attn_mask = mask_to_bias(attn_mask == 1, x.dtype) + attn_mask = (torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) == 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) for transformer_block in transformer_blocks: x = transformer_block( hidden_states=x, @@ -268,9 +558,8 @@ class ConditionalDecoder(nn.Module): for resnet, transformer_blocks in self.mid_blocks: x = resnet(x, mask_mid, t) x = rearrange(x, "b c t -> b t c").contiguous() - # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) - attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1) - attn_mask = mask_to_bias(attn_mask == 1, x.dtype) + attn_mask = (torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) == 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) for transformer_block in transformer_blocks: x = transformer_block( hidden_states=x, @@ -285,9 +574,8 @@ class ConditionalDecoder(nn.Module): x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] x = resnet(x, mask_up, t) x = rearrange(x, "b c t -> b t c").contiguous() - # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) - attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1) - attn_mask = mask_to_bias(attn_mask == 1, x.dtype) + attn_mask = (torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) == 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) for transformer_block in transformer_blocks: x = transformer_block( hidden_states=x, @@ -299,3 +587,296 @@ class ConditionalDecoder(nn.Module): x = self.final_block(x, mask_up) output = self.final_proj(x * mask_up) return output * mask + + +class CausalConditionalDecoder(ConditionalDecoder): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + static_chunk_size=50, + num_decoding_left_chunks=2, + ): + """ + This decoder requires an input with the same shape of the target. So, if your text content + is shorter or longer than the outputs, please re-sampling it before feeding to the decoder. 
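+        Compared with ConditionalDecoder, this variant uses causal convolutions and chunk-masked / cached self-attention, so it can also be run chunk by chunk through forward_chunk, carrying per-block convolution and key/value caches between calls.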
+ """ + torch.nn.Module.__init__(self) + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + self.static_chunk_size = static_chunk_size + self.num_decoding_left_chunks = num_decoding_left_chunks + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.ModuleList( + [ + CausalBasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) + ) + self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) + + for _ in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.ModuleList( + [ + CausalBasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i in range(len(channels) - 1): + input_channel = channels[i] * 2 + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + resnet = CausalResnetBlock1D( + dim=input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + CausalBasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else CausalConv1d(output_channel, output_channel, 3) + ) + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + self.final_block = CausalBlock1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + self.initialize_weights() + + def forward(self, x, mask, mu, t, spks=None, cond=None): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. 
+ + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + t = self.time_embeddings(t).to(t.dtype) + t = self.time_mlp(t) + + x = pack([x, mu], "b * t")[0] + + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + if cond is not None: + x = pack([x, cond], "b * t")[0] + + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x, _, _ = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x, _ = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + hiddens.append(x) # Save hidden states for skip connections + x, _ = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x, _, _ = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x, _ = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + skip = hiddens.pop() + x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] + x, _, _ = resnet(x, mask_up, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x, _ = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + x, _ = upsample(x * mask_up) + x, _ = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + return output * mask + + def forward_chunk(self, x, mask, mu, t, spks=None, cond=None, + down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0), + mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0), + up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0), + final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0) + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. 
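+            down_blocks_conv_cache / mid_blocks_conv_cache / up_blocks_conv_cache / final_blocks_conv_cache (torch.Tensor, optional): causal-convolution caches, i.e. the right-most input frames of each CausalConv1d from the previous chunk.
+            down_blocks_kv_cache / mid_blocks_kv_cache / up_blocks_kv_cache (torch.Tensor, optional): self-attention key/value caches of previous chunks, concatenated along the time axis inside each CausalBasicTransformerBlock; the updated caches are returned together with the estimator output.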
+
+        Returns:
+            Estimator output of shape (batch, out_channels, time) masked by `mask`,
+            together with the updated conv/KV caches of the down, mid, up and final blocks.
+        """
+
+        t = self.time_embeddings(t).to(t.dtype)
+        t = self.time_mlp(t)
+
+        x = pack([x, mu], "b * t")[0]
+
+        if spks is not None:
+            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+            x = pack([x, spks], "b * t")[0]
+        if cond is not None:
+            x = pack([x, cond], "b * t")[0]
+
+        hiddens = []
+        masks = [mask]
+
+        down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
+        mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device)
+        up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
+        for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
+            mask_down = masks[-1]
+            x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool()
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for i, transformer_block in enumerate(transformer_blocks):
+                x, down_blocks_kv_cache_new[index, i] = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                    cache=down_blocks_kv_cache[index, i],
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+            hiddens.append(x)  # Save hidden states for skip connections
+            x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:])
+            masks.append(mask_down[:, :, ::2])
+        masks = masks[:-1]
+        mask_mid = masks[-1]
+
+        for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
+            x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool()
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for i, transformer_block in enumerate(transformer_blocks):
+                x, mid_blocks_kv_cache_new[index, i] = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                    cache=mid_blocks_kv_cache[index, i]
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+
+        for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks):
+            mask_up = masks.pop()
+            skip = hiddens.pop()
+            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
+            x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
+            x = rearrange(x, "b c t -> b t c").contiguous()
+            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool()
+            attn_mask = mask_to_bias(attn_mask, x.dtype)
+            for i, transformer_block in enumerate(transformer_blocks):
+                x, up_blocks_kv_cache_new[index, i] = transformer_block(
+                    hidden_states=x,
+                    attention_mask=attn_mask,
+                    timestep=t,
+                    cache=up_blocks_kv_cache[index, i]
+                )
+            x = rearrange(x, "b t c -> b c t").contiguous()
+            x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])
+        x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
+        output = self.final_proj(x * mask_up)
+        return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache
diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py
index 72bb34c..71f5ae7 100644
--- a/cosyvoice/flow/flow.py
+++ b/cosyvoice/flow/flow.py
@@ -91,6 +91,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(feat_len)).to(h)
+        # NOTE this interpolation should not be needed: h has already gone through the length_regulator and matches feat's shape
         feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
         loss, _ = self.decoder.compute_loss(
             feat.transpose(1, 2).contiguous(),
@@ -190,6 +191,49 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         self.token_mel_ratio = token_mel_ratio
         self.pre_lookahead_len = pre_lookahead_len
 
+    def forward(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        token = batch['speech_token'].to(device)
+        token_len = batch['speech_token_len'].to(device)
+        feat = batch['speech_feat'].to(device)
+        feat_len = batch['speech_feat_len'].to(device)
+        embedding = batch['embedding'].to(device)
+
+        # xvec projection
+        embedding = F.normalize(embedding, dim=1)
+        embedding = self.spk_embed_affine_layer(embedding)
+
+        # concat text and prompt_text
+        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
+        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+        # text encode
+        h, h_lengths = self.encoder(token, token_len)
+        h = self.encoder_proj(h)
+
+        # get conditions
+        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
+        conds = torch.zeros(feat.shape, device=token.device)
+        for i, j in enumerate(feat_len):
+            if random.random() < 0.5:
+                continue
+            index = random.randint(0, int(0.3 * j))
+            conds[i, :index] = feat[i, :index]
+        conds = conds.transpose(1, 2)
+
+        mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
+        loss, _ = self.decoder.compute_loss(
+            feat.transpose(1, 2).contiguous(),
+            mask.unsqueeze(1),
+            h.transpose(1, 2).contiguous(),
+            embedding,
+            cond=conds
+        )
+        return {'loss': loss}
+
     @torch.inference_mode()
     def inference(self,
                   token,
@@ -199,6 +243,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                   prompt_feat,
                   prompt_feat_len,
                   embedding,
+                  cache,
                   finalize):
         if self.fp16 is True:
             prompt_feat = prompt_feat.half()
@@ -215,9 +260,17 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
 
         # text encode
-        h, h_lengths = self.encoder(token, token_len)
-        if finalize is False:
-            h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
+        if finalize is True:
+            h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache'])
+        else:
+            token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
+            h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache'])
+        cache['encoder_cache']['offset'] = encoder_cache[0]
+        cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1]
+        cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2]
+        cache['encoder_cache']['upsample_offset'] = encoder_cache[3]
+        cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4]
+        cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5]
         mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
         h = self.encoder_proj(h)
 
@@ -227,13 +280,14 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
-        feat, _ = self.decoder(
+        feat, cache['decoder_cache'] = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
             spks=embedding,
             cond=conds,
-            n_timesteps=10
+            n_timesteps=10,
+            cache=cache['decoder_cache']
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat.float(), None
+        return feat.float(), cache
diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py
index 6a60f6d..3a7de2e 100644
--- a/cosyvoice/flow/flow_matching.py
+++ b/cosyvoice/flow/flow_matching.py
@@ -34,7 +34,7 @@ class ConditionalCFM(BASECFM):
         self.lock = threading.Lock()
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
         """Forward diffusion
 
         Args:
@@ -54,19 +54,19 @@ class ConditionalCFM(BASECFM):
         """
         z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
-        cache_size = flow_cache.shape[2]
+        cache_size = cache.shape[2]
         # fix prompt and overlap part mu and z
         if cache_size != 0:
-            z[:, :, :cache_size] = flow_cache[:, :, :, 0]
-            mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
+            z[:, :, :cache_size] = cache[:, :, :, 0]
+            mu[:, :, :cache_size] = cache[:, :, :, 1]
         z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
         mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
-        flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
+        cache = torch.stack([z_cache, mu_cache], dim=-1)
 
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
 
     def solve_euler(self, x, t_span, mu, mask, spks, cond):
         """
@@ -123,7 +123,7 @@ class ConditionalCFM(BASECFM):
 
     def forward_estimator(self, x, mask, mu, t, spks, cond):
        if isinstance(self.estimator, torch.nn.Module):
-            return self.estimator.forward(x, mask, mu, t, spks, cond)
+            return self.estimator(x, mask, mu, t, spks, cond)
         else:
             with self.lock:
                 self.estimator.set_input_shape('x', (2, 80, x.size(2)))
@@ -190,7 +193,7 @@ class CausalConditionalCFM(ConditionalCFM):
         self.rand_noise = torch.randn([1, 80, 50 * 300])
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}):
         """Forward diffusion
 
         Args:
@@ -209,9 +212,105 @@ class CausalConditionalCFM(ConditionalCFM):
             shape: (batch_size, n_feats, mel_timesteps)
         """
-        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
+        offset = cache.pop('offset')
+        z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
+        z = z[:, :, offset:]
+        offset += mu.size(2)
         # fix prompt and overlap part mu and z
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) if self.t_scheduler == 'cosine': t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) - return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None + mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache) + cache['offset'] = offset + return mel, cache + + def solve_euler(self, x, t_span, mu, mask, spks, cond, cache): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + t = t.unsqueeze(dim=0) + + # I am storing this because I can later plot it by putting a debugger here and saving it to a file + # Or in future might add like a return_all_steps flag + sol = [] + + # estimator cache for each step + down_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x.device) + mid_blocks_kv_cache_new = torch.zeros(10, 12, 4, 2, x.size(2), 512, 2).to(x.device) + up_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x.device) + + # Do not use concat, it may cause memory format changed and trt infer with wrong results! + x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) + mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype) + mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) + t_in = torch.zeros([2], device=x.device, dtype=x.dtype) + spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype) + cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) + for step in range(1, len(t_span)): + # Classifier-Free Guidance inference introduced in VoiceBox + x_in[:] = x + mask_in[:] = mask + mu_in[0] = mu + t_in[:] = t.unsqueeze(0) + spks_in[0] = spks + cond_in[0] = cond + cache_step = {k: v[step - 1] for k, v in cache.items()} + dphi_dt, cache_step = self.forward_estimator( + x_in, mask_in, + mu_in, t_in, + spks_in, + cond_in, + cache_step + ) + cache['down_blocks_conv_cache'][step - 1] = cache_step[0] + down_blocks_kv_cache_new[step - 1] = cache_step[1] + cache['mid_blocks_conv_cache'][step - 1] = cache_step[2] + mid_blocks_kv_cache_new[step - 1] = cache_step[3] + cache['up_blocks_conv_cache'][step - 1] = cache_step[4] + up_blocks_kv_cache_new[step - 1] = cache_step[5] + cache['final_blocks_conv_cache'][step - 1] = cache_step[6] + dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) + dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + cache['down_blocks_kv_cache'] = torch.concat([cache['down_blocks_kv_cache'], down_blocks_kv_cache_new], dim=4) + cache['mid_blocks_kv_cache'] = torch.concat([cache['mid_blocks_kv_cache'], mid_blocks_kv_cache_new], dim=4) + cache['up_blocks_kv_cache'] = torch.concat([cache['up_blocks_kv_cache'], up_blocks_kv_cache_new], dim=4) + return sol[-1].float(), cache + + def forward_estimator(self, x, mask, mu, t, spks, cond, cache): + if isinstance(self.estimator, torch.nn.Module): + x, 
cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache) + cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7) + else: + with self.lock: + self.estimator.set_input_shape('x', (2, 80, x.size(2))) + self.estimator.set_input_shape('mask', (2, 1, x.size(2))) + self.estimator.set_input_shape('mu', (2, 80, x.size(2))) + self.estimator.set_input_shape('t', (2,)) + self.estimator.set_input_shape('spks', (2, 80)) + self.estimator.set_input_shape('cond', (2, 80, x.size(2))) + # run trt engine + self.estimator.execute_v2([x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()]) + return x, cache diff --git a/cosyvoice/hifigan/discriminator.py b/cosyvoice/hifigan/discriminator.py index 1a4dcc8..5435660 100644 --- a/cosyvoice/hifigan/discriminator.py +++ b/cosyvoice/hifigan/discriminator.py @@ -1,6 +1,9 @@ import torch import torch.nn as nn -from torch.nn.utils.parametrizations import weight_norm +try: + from torch.nn.utils.parametrizations import weight_norm +except ImportError: + from torch.nn.utils import weight_norm from typing import List, Optional, Tuple from einops import rearrange from torchaudio.transforms import Spectrogram diff --git a/cosyvoice/hifigan/f0_predictor.py b/cosyvoice/hifigan/f0_predictor.py index 172c5f5..5797c31 100644 --- a/cosyvoice/hifigan/f0_predictor.py +++ b/cosyvoice/hifigan/f0_predictor.py @@ -13,7 +13,10 @@ # limitations under the License. import torch import torch.nn as nn -from torch.nn.utils.parametrizations import weight_norm +try: + from torch.nn.utils.parametrizations import weight_norm +except ImportError: + from torch.nn.utils import weight_norm class ConvRNNF0Predictor(nn.Module): diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py index c47bf05..50d7f99 100644 --- a/cosyvoice/hifigan/generator.py +++ b/cosyvoice/hifigan/generator.py @@ -23,7 +23,10 @@ import torch.nn.functional as F from torch.nn import Conv1d from torch.nn import ConvTranspose1d from torch.nn.utils import remove_weight_norm -from torch.nn.utils.parametrizations import weight_norm +try: + from torch.nn.utils.parametrizations import weight_norm +except ImportError: + from torch.nn.utils import weight_norm from torch.distributions.uniform import Uniform from cosyvoice.transformer.activation import Snake diff --git a/cosyvoice/transformer/embedding.py b/cosyvoice/transformer/embedding.py index eae8c8e..ba20d71 100644 --- a/cosyvoice/transformer/embedding.py +++ b/cosyvoice/transformer/embedding.py @@ -287,8 +287,16 @@ class EspnetRelPositionalEncoding(torch.nn.Module): Returns: torch.Tensor: Corresponding encoding """ - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size, - ] + # How to subscript a Union type: + # https://github.com/pytorch/pytorch/issues/69434 + if isinstance(offset, int): + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset, + ] + elif isinstance(offset, torch.Tensor): + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset, + ] return pos_emb diff --git a/cosyvoice/transformer/upsample_encoder.py b/cosyvoice/transformer/upsample_encoder.py index f67fb98..6032cac 100644 --- a/cosyvoice/transformer/upsample_encoder.py +++ b/cosyvoice/transformer/upsample_encoder.py @@ -56,11 +56,16 @@ class 
Upsample1D(nn.Module): # In this mode, first repeat interpolate, than conv with stride=1 self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0) - def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor): + def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor=torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest") - outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0) + if conv_cache.size(2) == 0: + outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0) + else: + assert conv_cache.size(2) == self.stride * 2 + outputs = torch.concat([conv_cache, outputs], dim=2) + conv_cache_new = outputs[:, :, -self.stride * 2:] outputs = self.conv(outputs) - return outputs, input_lengths * self.stride + return outputs, input_lengths * self.stride, conv_cache_new class PreLookaheadLayer(nn.Module): @@ -78,22 +83,32 @@ class PreLookaheadLayer(nn.Module): kernel_size=3, stride=1, padding=0, ) - def forward(self, inputs: torch.Tensor) -> torch.Tensor: + def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0), conv2_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]: """ inputs: (batch_size, seq_len, channels) """ outputs = inputs.transpose(1, 2).contiguous() + context = context.transpose(1, 2).contiguous() # look ahead - outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0) + if context.size(2) == 0: + outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0) + else: + assert context.size(2) == self.pre_lookahead_len + outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0) outputs = F.leaky_relu(self.conv1(outputs)) # outputs - outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0) + if conv2_cache.size(2) == 0: + outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0) + else: + assert conv2_cache.size(2) == self.conv2.kernel_size[0] - 1 + outputs = torch.concat([conv2_cache, outputs], dim=2) + conv2_cache_new = outputs[:, :, -(self.conv2.kernel_size[0] - 1):] outputs = self.conv2(outputs) outputs = outputs.transpose(1, 2).contiguous() # residual connection outputs = outputs + inputs - return outputs + return outputs, conv2_cache_new class UpsampleConformerEncoder(torch.nn.Module): @@ -277,12 +292,12 @@ class UpsampleConformerEncoder(torch.nn.Module): self.static_chunk_size, num_decoding_left_chunks) # lookahead + conformer encoder - xs = self.pre_lookahead_layer(xs) + xs, _ = self.pre_lookahead_layer(xs) xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) # upsample + conformer encoder xs = xs.transpose(1, 2).contiguous() - xs, xs_lens = self.up_layer(xs, xs_lens) + xs, xs_lens, _ = self.up_layer(xs, xs_lens) xs = xs.transpose(1, 2).contiguous() T = xs.size(1) masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) @@ -316,3 +331,99 @@ class UpsampleConformerEncoder(torch.nn.Module): for layer in self.up_encoders: xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) return xs + + @torch.jit.export + def forward_chunk( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + offset: int = 0, + context: torch.Tensor = torch.zeros(0, 0, 0), + pre_lookahead_layer_conv2_cache: torch.Tensor = torch.zeros(0, 0, 0), + encoders_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 
0),
+        upsample_offset: int = 0,
+        upsample_conv_cache: torch.Tensor = torch.zeros(0, 0, 0),
+        upsample_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0)
+    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[int, torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]]:
+        """Chunk-wise forward for streaming inference.
+
+        Args:
+            xs: padded input tensor (B=1, T, D)
+            xs_lens: input length (B)
+            offset: number of input frames already consumed by previous chunks
+            context: pre-lookahead frames of the next chunk (empty when finalizing)
+            pre_lookahead_layer_conv2_cache: conv cache of the pre-lookahead layer
+            encoders_kv_cache: per-layer KV cache of the first conformer stack
+            upsample_offset: number of upsampled frames already consumed
+            upsample_conv_cache: conv cache of the upsample layer
+            upsample_kv_cache: per-layer KV cache of the upsampled conformer stack
+        Returns:
+            encoder output tensor xs, its padding mask, and the updated caches
+            (offset, pre-lookahead conv cache, encoder KV cache, upsample offset,
+            upsample conv cache, upsample KV cache) to be fed to the next chunk.
+        """
+        assert xs.size(0) == 1
+        # tmp_masks is just for interface compatibility
+        tmp_masks = torch.ones(1,
+                               xs.size(1),
+                               device=xs.device,
+                               dtype=torch.bool)
+        tmp_masks = tmp_masks.unsqueeze(1)
+        if self.global_cmvn is not None:
+            xs = self.global_cmvn(xs)
+        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
+        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
+        offset += xs.size(1)
+        tmp_masks = torch.ones(1,
+                               context.size(1),
+                               device=context.device,
+                               dtype=torch.bool)
+        tmp_masks = tmp_masks.unsqueeze(1)
+        if context.size(1) != 0:
+            context, _, _ = self.embed(context, tmp_masks, offset)
+
+        # lookahead + conformer encoder
+        xs, pre_lookahead_layer_conv2_cache = self.pre_lookahead_layer(xs, context, pre_lookahead_layer_conv2_cache)
+        # NOTE in cache mode we do not need to call add_optional_chunk_mask
+        chunk_masks = torch.ones((1, xs.size(1), offset), dtype=torch.bool, device=xs.device)
+        mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
+        encoders_kv_cache_list = []
+        for index, layer in enumerate(self.encoders):
+            xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
+            encoders_kv_cache_list.append(encoders_kv_cache_new)
+        encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)
+
+        # upsample
+        xs = xs.transpose(1, 2).contiguous()
+        xs, xs_lens, upsample_conv_cache = self.up_layer(xs, xs_lens, upsample_conv_cache)
+        xs = xs.transpose(1, 2).contiguous()
+
+        # tmp_masks is just for interface compatibility
+        tmp_masks = torch.ones(1,
+                               xs.size(1),
+                               device=xs.device,
+                               dtype=torch.bool)
+        tmp_masks = tmp_masks.unsqueeze(1)
+        xs, pos_emb, masks = self.up_embed(xs, tmp_masks, upsample_offset)
+        upsample_offset += xs.size(1)
+
+        # conformer encoder
+        chunk_masks = torch.ones((1, xs.size(1), upsample_offset), dtype=torch.bool, device=xs.device)
+        mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
+        upsample_kv_cache_list = []
+        for index, layer in enumerate(self.up_encoders):
+            xs, chunk_masks, upsample_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, upsample_kv_cache[index])
+            upsample_kv_cache_list.append(upsample_kv_cache_new)
+        upsample_kv_cache = torch.stack(upsample_kv_cache_list, dim=0)
+
+        if self.normalize_before:
+            xs = self.after_norm(xs)
+        # Here we assume the mask is not changed in encoder layers, so just
+        # return the masks before encoder layers, and the masks will be used
+        # for cross attention with decoder later
+        return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache)
diff --git a/cosyvoice/utils/mask.py b/cosyvoice/utils/mask.py
index c164db1..dabd19d 100644
--- a/cosyvoice/utils/mask.py
+++ b/cosyvoice/utils/mask.py
@@ -87,7 +87,7 @@ def subsequent_mask(
     return mask
 
 
-def subsequent_chunk_mask_deprecated(
+def subsequent_chunk_mask(
     size: int,
     chunk_size: int,
     num_left_chunks: int = -1,
@@ -125,41 +125,6 @@ def subsequent_chunk_mask_deprecated(
     return ret
 
 
-def subsequent_chunk_mask(
-    size: int,
-    chunk_size: int,
-    num_left_chunks: int = -1,
-    device: torch.device = torch.device("cpu"),
-) -> torch.Tensor:
-    """Create mask for subsequent steps (size, size) with chunk size,
-       this is for streaming encoder
-
-    Args:
-        size (int): size of mask
-        chunk_size (int): size of chunk
-        num_left_chunks (int): number of left chunks
-            <0: use full chunk
-            >=0: use num_left_chunks
-        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
-
-    Returns:
-        torch.Tensor: mask
-
-    Examples:
-        >>> subsequent_chunk_mask(4, 2)
-        [[1, 1, 0, 0],
-         [1, 1, 0, 0],
-         [1, 1, 1, 1],
-         [1, 1, 1, 1]]
-    """
-    # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
-    # actually this is not needed after we have inference cache implemented, will remove it later
-    pos_idx = torch.arange(size, device=device)
-    block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
-    ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
-    return ret
-
-
 def add_optional_chunk_mask(xs: torch.Tensor,
                             masks: torch.Tensor,
                             use_dynamic_chunk: bool,
diff --git a/examples/libritts/cosyvoice2/conf/cosyvoice2.yaml b/examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
new file mode 100644
index 0000000..f989e3e
--- /dev/null
+++ b/examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
@@ -0,0 +1,232 @@
+# set random seed, so that you may reproduce your result.
+__set_seed1: !apply:random.seed [1986]
+__set_seed2: !apply:numpy.random.seed [1986]
+__set_seed3: !apply:torch.manual_seed [1986]
+__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
+
+# fixed params
+sample_rate: 24000
+llm_input_size: 896
+llm_output_size: 896
+spk_embed_dim: 192
+qwen_pretrain_path: ''
+token_frame_rate: 25
+token_mel_ratio: 2
+
+# model params
+# for all class/function included in this repo, we use !<name> or !<new> for initialization, so that user may find all corresponding class/function according to one single yaml.
+# for system/third_party class/function, we do not require this.
+llm: !new:cosyvoice.llm.llm.Qwen2LM
+    llm_input_size: !ref <llm_input_size>
+    llm_output_size: !ref <llm_output_size>
+    speech_token_size: 6561
+    length_normalized_loss: True
+    lsm_weight: 0
+    mix_ratio: [5, 15]
+    llm: !new:cosyvoice.llm.llm.Qwen2Encoder
+        pretrain_path: !ref <qwen_pretrain_path>
+    sampling: !name:cosyvoice.utils.common.ras_sampling
+        top_p: 0.8
+        top_k: 25
+        win_size: 10
+        tau_r: 0.1
+
+flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
+    input_size: 512
+    output_size: 80
+    spk_embed_dim: !ref <spk_embed_dim>
+    output_type: 'mel'
+    vocab_size: 6561
+    input_frame_rate: !ref <token_frame_rate>
+    only_mask_loss: True
+    token_mel_ratio: !ref <token_mel_ratio>
+    pre_lookahead_len: 3
+    encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
+        output_size: 512
+        attention_heads: 8
+        linear_units: 2048
+        num_blocks: 6
+        dropout_rate: 0.1
+        positional_dropout_rate: 0.1
+        attention_dropout_rate: 0.1
+        normalize_before: True
+        input_layer: 'linear'
+        pos_enc_layer_type: 'rel_pos_espnet'
+        selfattention_layer_type: 'rel_selfattn'
+        input_size: 512
+        use_cnn_module: False
+        macaron_style: False
+        use_dynamic_chunk: True
+    decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
+        in_channels: 240
+        n_spks: 1
+        spk_emb_dim: 80
+        cfm_params: !new:omegaconf.DictConfig
+            content:
+                sigma_min: 1e-06
+                solver: 'euler'
+                t_scheduler: 'cosine'
+                training_cfg_rate: 0.2
+                inference_cfg_rate: 0.7
+                reg_loss_type: 'l1'
+        estimator: !new:cosyvoice.flow.decoder.CausalConditionalDecoder
+            in_channels: 320
+            out_channels: 80
+            channels: [256]
+            dropout: 0.0
+            attention_head_dim: 64
+            n_blocks: 4
+            num_mid_blocks: 12
+            num_heads: 8
+            act_fn: 'gelu'
+            static_chunk_size: !ref <token_frame_rate> * <token_mel_ratio> # here we use static_chunk_size because we want to fix kv cache size during inference
+            num_decoding_left_chunks: 2
+
+hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
+    in_channels: 80
+    base_channels: 512
+    nb_harmonics: 8
+    sampling_rate: !ref <sample_rate>
+    nsf_alpha: 0.1
+    nsf_sigma: 0.003
+    nsf_voiced_threshold: 10
+    upsample_rates: [8, 5, 3]
+    upsample_kernel_sizes: [16, 11, 7]
+    istft_params:
+        n_fft: 16
+        hop_len: 4
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    source_resblock_kernel_sizes: [7, 7, 11]
+    source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    lrelu_slope: 0.1
+    audio_limit: 0.99
+    f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
+        num_class: 1
+        in_channels: 80
+        cond_channels: 512
+
+# gan related module
+mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
+    n_fft: 1024
+    num_mels: 80
+    sampling_rate: !ref <sample_rate>
+    hop_size: 256
+    win_size: 1024
+    fmin: 0
+    fmax: null
+    center: False
+hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
+    generator: !ref <hift>
+    discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
+        mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
+        mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
+    mel_spec_transform: [
+        !ref <mel_spec_transform1>
+    ]
+
+# processor functions
+parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
+get_tokenizer: !name:cosyvoice.tokenizer.tokenizer.get_qwen_tokenizer
+    token_path: !ref <qwen_pretrain_path>
+    skip_special_tokens: True
+allowed_special: 'all'
+tokenize: !name:cosyvoice.dataset.processor.tokenize
+    get_tokenizer: !ref <get_tokenizer>
+    allowed_special: !ref <allowed_special>
+filter: !name:cosyvoice.dataset.processor.filter
+    max_length: 40960
+    min_length: 100
+    token_max_length: 200
+    token_min_length: 1
+resample: !name:cosyvoice.dataset.processor.resample
+    resample_rate: !ref <sample_rate>
+truncate: !name:cosyvoice.dataset.processor.truncate
+    truncate_length: 24480 # must be a multiple of hop_size
+feat_extractor: !name:matcha.utils.audio.mel_spectrogram
+    n_fft: 1920
+    num_mels: 80
+    sampling_rate: !ref <sample_rate>
+    hop_size: 480
+    win_size: 1920
+    fmin: 0
+    fmax: 8000
+    center: False
+compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
+    feat_extractor: !ref <feat_extractor>
+# pitch_extractor: !name:torchaudio.functional.compute_kaldi_pitch # TODO need to replace it
+#     sample_rate: !ref <sample_rate>
+#     frame_length: 46.4 # match feat_extractor win_size/sampling_rate
+#     frame_shift: 11.6 # match feat_extractor hop_size/sampling_rate
+# compute_f0: !name:cosyvoice.dataset.processor.compute_f0
+#     pitch_extractor: !ref <pitch_extractor>
+parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
+    normalize: True
+shuffle: !name:cosyvoice.dataset.processor.shuffle
+    shuffle_size: 1000
+sort: !name:cosyvoice.dataset.processor.sort
+    sort_size: 500 # sort_size should be less than shuffle_size
+batch: !name:cosyvoice.dataset.processor.batch
+    batch_type: 'dynamic'
+    max_frames_in_batch: 2500
+padding: !name:cosyvoice.dataset.processor.padding
+    use_spk_embedding: False # change to True during sft
+
+
+# dataset processor pipeline
+data_pipeline: [
+    !ref <parquet_opener>,
+    !ref <tokenize>,
+    !ref <filter>,
+    !ref <resample>,
+    !ref <compute_fbank>,
+    !ref <parse_embedding>,
+    !ref <shuffle>,
+    !ref <sort>,
+    !ref <batch>,
+    !ref <padding>,
+]
+# data_pipeline_gan: [
+#     !ref <parquet_opener>,
+#     !ref <tokenize>,
+#     !ref <filter>,
+#     !ref <resample>,
+#     !ref <truncate>,
+#     !ref <compute_fbank>,
+#     !ref <compute_f0>,
+#     !ref <parse_embedding>,
+#     !ref <shuffle>,
+#     !ref <sort>,
+#     !ref <batch>,
+#     !ref <padding>,
+# ]
+
+# llm flow train conf
+train_conf:
+    optim: adam
+    optim_conf:
+        lr: 1e-5 # change to 1e-5 during sft
+    scheduler: constantlr # change to constantlr during sft
+    scheduler_conf:
+        warmup_steps: 2500
+    max_epoch: 200
+    grad_clip: 5
+    accum_grad: 2
+    log_interval: 100
+    save_per_step: -1
+
+# gan train conf
+train_conf_gan:
+    optim: adam
+    optim_conf:
+        lr: 0.0002 # use small lr for gan training
+    scheduler: constantlr
+    optim_d: adam
+    optim_conf_d:
+        lr: 0.0002 # use small lr for gan training
+    scheduler_d: constantlr
+    max_epoch: 200
+    grad_clip: 5
+    accum_grad: 1 # in gan training, accum_grad must be 1
+    log_interval: 100
+    save_per_step: -1
\ No newline at end of file
diff --git a/examples/libritts/cosyvoice2/conf/ds_stage2.json b/examples/libritts/cosyvoice2/conf/ds_stage2.json
new file mode 100644
index 0000000..2b2de3d
--- /dev/null
+++ b/examples/libritts/cosyvoice2/conf/ds_stage2.json
@@ -0,0 +1,42 @@
+{
+  "train_micro_batch_size_per_gpu": 1,
+  "gradient_accumulation_steps": 1,
+  "steps_per_print": 100,
+  "gradient_clipping": 5,
+  "fp16": {
+    "enabled": false,
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 16,
+    "loss_scale_window": 256,
+    "hysteresis": 2,
+    "consecutive_hysteresis": false,
+    "min_loss_scale": 1
+  },
+  "bf16": {
+    "enabled": false
+  },
+  "zero_force_ds_cpu_optimizer": false,
+  "zero_optimization": {
+    "stage": 2,
+    "offload_optimizer": {
+      "device": "none",
+      "pin_memory": true
+    },
+    "allgather_partitions": true,
+    "allgather_bucket_size": 5e8,
+    "overlap_comm": false,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 5e8,
+    "contiguous_gradients" : true
+  },
+  "optimizer": {
+    "type": "AdamW",
+    "params": {
+      "lr": 0.001,
+      "weight_decay": 0.0001,
+      "torch_adam": true,
+      "adam_w_mode": true
+    }
+  }
+}
\ No newline at end of file
diff --git a/examples/libritts/cosyvoice2/path.sh b/examples/libritts/cosyvoice2/path.sh
new file mode 100644
index 0000000..e0fa06c
--- /dev/null
+++ b/examples/libritts/cosyvoice2/path.sh
@@ -0,0 +1,3 @@
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=../../../:../../../third_party/Matcha-TTS:$PYTHONPATH
diff --git a/examples/libritts/cosyvoice2/run.sh b/examples/libritts/cosyvoice2/run.sh
new file mode 100644
index 0000000..4168a78
--- /dev/null
+++ b/examples/libritts/cosyvoice2/run.sh
@@ -0,0 +1,130 @@
+#!/bin/bash
+# Copyright 2024 Alibaba Inc. All Rights Reserved.
+. ./path.sh || exit 1;
+
+stage=-1
+stop_stage=3
+
+data_url=www.openslr.org/resources/60
+data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
+pretrained_model_dir=/mnt/lyuxiang.lx/data/tts/models/IIC/CosyVoice2-0.5B/
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+  echo "Data Download"
+  for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
+    local/download_and_untar.sh ${data_dir} ${data_url} ${part}
+  done
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+  echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
+  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+    mkdir -p data/$x
+    python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
+  done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
+  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+    tools/extract_embedding.py --dir data/$x \
+      --onnx_path $pretrained_model_dir/campplus.onnx
+  done
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
+  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+    tools/extract_speech_token.py --dir data/$x \
+      --onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
+  done
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
+  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+    mkdir -p data/$x/parquet
+    tools/make_parquet_list.py --num_utts_per_parquet 1000 \
+      --num_processes 10 \
+      --src_dir data/$x \
+      --des_dir data/$x/parquet
+  done
+fi
+
+# inference
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+  echo "Run inference. Please make sure utt in tts_text is in prompt_data"
+  # TODO consider remove bin/inference.py, or use similar initialization method as in readme
+  for mode in sft zero_shot; do
+    python cosyvoice/bin/inference.py --mode $mode \
+      --gpu 0 \
+      --config conf/cosyvoice2.yaml \
+      --prompt_data data/test-clean/parquet/data.list \
+      --prompt_utt2data data/test-clean/parquet/utt2data.list \
+      --tts_text `pwd`/tts_text.json \
+      --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
+      --llm_model $pretrained_model_dir/llm.pt \
+      --flow_model $pretrained_model_dir/flow.pt \
+      --hifigan_model $pretrained_model_dir/hift.pt \
+      --result_dir `pwd`/exp/cosyvoice/test-clean/$mode
+  done
+fi
+
+# train llm
+export CUDA_VISIBLE_DEVICES="2,3,4,5,6,7"
+num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+job_id=1986
+dist_backend="nccl"
+num_workers=2
+prefetch=100
+train_engine=torch_ddp
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+  echo "Run train. We only support flow training for now. If you want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
+  if [ $train_engine == 'deepspeed' ]; then
+    echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
+  fi
+  cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
+  cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
+  # NOTE will update llm/hift training later
+  for model in flow; do
+    torchrun --nnodes=1 --nproc_per_node=$num_gpus \
+        --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
+      cosyvoice/bin/train.py \
+      --train_engine $train_engine \
+      --config conf/cosyvoice2.yaml \
+      --train_data data/train.data.list \
+      --cv_data data/dev.data.list \
+      --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
+      --model $model \
+      --checkpoint $pretrained_model_dir/$model.pt \
+      --model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \
+      --tensorboard_dir `pwd`/tensorboard/cosyvoice2/$model/$train_engine \
+      --ddp.dist_backend $dist_backend \
+      --num_workers ${num_workers} \
+      --prefetch ${prefetch} \
+      --pin_memory \
+      --use_amp \
+      --deepspeed_config ./conf/ds_stage2.json \
+      --deepspeed.save_states model+optimizer
+  done
+fi
+
+# average model
+average_num=5
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+  for model in llm flow hifigan; do
+    decode_checkpoint=`pwd`/exp/cosyvoice2/$model/$train_engine/${model}.pt
+    echo "do model average and final checkpoint is $decode_checkpoint"
+    python cosyvoice/bin/average_model.py \
+      --dst_model $decode_checkpoint \
+      --src_path `pwd`/exp/cosyvoice2/$model/$train_engine  \
+      --num ${average_num} \
+      --val_best
+  done
+fi
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+  echo "Export your model for inference speedup. Remember to copy your llm or flow model to model_dir"
+  python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
+  python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
+fi
\ No newline at end of file
diff --git a/examples/libritts/cosyvoice2/tts_text.json b/examples/libritts/cosyvoice2/tts_text.json
new file mode 100644
index 0000000..9f3e8d9
--- /dev/null
+++ b/examples/libritts/cosyvoice2/tts_text.json
@@ -0,0 +1,5 @@
+{
+  "1089_134686_000002_000000": [
+    "hello, my name is Jack. What is your name?"
+  ]
+}
\ No newline at end of file
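
Note on the conv caches introduced above (PreLookaheadLayer.forward, Upsample1D.forward and the decoder conv caches): the pattern is the usual streaming-convolution trick of carrying the last kernel_size - 1 padded input frames between chunks. The following is a minimal, self-contained sketch in plain PyTorch (a toy Conv1d, not the CosyVoice modules) showing that the chunked pass reproduces the full-sequence pass exactly:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv1d(4, 4, kernel_size=3, stride=1, padding=0)
x = torch.randn(1, 4, 20)  # (batch, channels, time)

# full-sequence pass: left-pad by kernel_size - 1 so the conv stays causal
full = conv(F.pad(x, (conv.kernel_size[0] - 1, 0)))

# chunked pass with a conv cache
cache = torch.zeros(1, 4, 0)
outs = []
for chunk in x.split(5, dim=2):
    if cache.size(2) == 0:
        padded = F.pad(chunk, (conv.kernel_size[0] - 1, 0))   # first chunk: zero left-padding
    else:
        padded = torch.concat([cache, chunk], dim=2)          # later chunks: reuse cached tail
    cache = padded[:, :, -(conv.kernel_size[0] - 1):]         # carry the tail to the next chunk
    outs.append(conv(padded))
chunked = torch.concat(outs, dim=2)

torch.testing.assert_close(full, chunked)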
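
Similarly, the encoders_kv_cache / upsample_kv_cache threading in UpsampleConformerEncoder.forward_chunk and in the estimator uses the standard KV-cache recipe: each new chunk attends to itself plus everything cached, which is equivalent to one full-sequence pass under a block-wise (chunk) causal mask. A small stand-alone sketch with single-head attention (plain PyTorch, illustrative shapes only, not the CosyVoice attention layers):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, chunk = 8, 4
x = torch.randn(1, 3 * chunk, d)                      # three chunks of a toy sequence
wq, wk, wv = (torch.randn(d, d) for _ in range(3))
q, k, v = x @ wq, x @ wk, x @ wv

# full-sequence pass with a block-causal mask: frame t may attend to any frame whose
# chunk index is <= its own chunk index
idx = torch.arange(3 * chunk) // chunk
block_mask = idx.unsqueeze(1) >= idx.unsqueeze(0)     # (T, T) bool
full = F.scaled_dot_product_attention(q, k, v, attn_mask=block_mask)

# chunked pass with a growing KV cache
k_cache = torch.zeros(1, 0, d)
v_cache = torch.zeros(1, 0, d)
outs = []
for qc, kc, vc in zip(q.split(chunk, dim=1), k.split(chunk, dim=1), v.split(chunk, dim=1)):
    k_cache = torch.cat([k_cache, kc], dim=1)
    v_cache = torch.cat([v_cache, vc], dim=1)
    # the current chunk attends to everything seen so far (all-ones mask of shape (chunk, cache_len))
    outs.append(F.scaled_dot_product_attention(qc, k_cache, v_cache))
chunked = torch.cat(outs, dim=1)

torch.testing.assert_close(full, chunked, rtol=1e-4, atol=1e-5)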
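
The new conf/cosyvoice2.yaml follows the HyperPyYAML tag conventions (!new:, !name:, !ref <key>) already used by the existing CosyVoice recipes. As a tiny standalone illustration of how such references resolve; the keys and the torch.stft placeholder below are invented for the example and assume the hyperpyyaml package is installed:

from hyperpyyaml import load_hyperpyyaml

conf_str = """
sample_rate: 24000
feat_extractor: !name:torch.stft
    n_fft: 1920
    hop_length: 480
resample_rate: !ref <sample_rate>
"""
conf = load_hyperpyyaml(conf_str)
assert conf["resample_rate"] == 24000      # !ref copies the referenced value
print(conf["feat_extractor"])              # functools.partial(torch.stft, n_fft=1920, hop_length=480)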