From 68100c267a0a4a01e88bb52511f64d1bd97c21fd Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Fri, 23 May 2025 12:50:47 +0800 Subject: [PATCH] remove flow_cache --- README.md | 2 +- cosyvoice/bin/export_jit.py | 7 +- cosyvoice/bin/export_onnx.py | 169 +++----- cosyvoice/cli/cosyvoice.py | 6 +- cosyvoice/cli/model.py | 100 ++--- cosyvoice/dataset/processor.py | 13 +- cosyvoice/flow/decoder.py | 459 ++-------------------- cosyvoice/flow/flow.py | 16 +- cosyvoice/flow/flow_matching.py | 168 ++------ cosyvoice/hifigan/generator.py | 170 +++++++- cosyvoice/transformer/upsample_encoder.py | 132 +------ cosyvoice/utils/file_utils.py | 2 +- cosyvoice/utils/mask.py | 39 +- test1.py | 37 -- 14 files changed, 365 insertions(+), 955 deletions(-) delete mode 100644 test1.py diff --git a/README.md b/README.md index 4a1dbd3..c7a724d 100644 --- a/README.md +++ b/README.md @@ -126,7 +126,7 @@ import torchaudio **CosyVoice2 Usage** ```python -cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False) +cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False) # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference # zero_shot usage diff --git a/cosyvoice/bin/export_jit.py b/cosyvoice/bin/export_jit.py index 1e89005..4eedc1a 100644 --- a/cosyvoice/bin/export_jit.py +++ b/cosyvoice/bin/export_jit.py @@ -61,8 +61,7 @@ def main(): model = CosyVoice(args.model_dir) except Exception: try: - # NOTE set use_flow_cache=True when export jit for cache inference - model = CosyVoice2(args.model_dir, use_flow_cache=True) + model = CosyVoice2(args.model_dir) except Exception: raise TypeError('no valid model_type!') @@ -93,9 +92,9 @@ def main(): else: # 3. export flow encoder flow_encoder = model.model.flow.encoder - script = get_optimized_script(flow_encoder, ['forward_chunk']) + script = get_optimized_script(flow_encoder) script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) - script = get_optimized_script(flow_encoder.half(), ['forward_chunk']) + script = get_optimized_script(flow_encoder.half()) script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir)) logging.info('successfully export flow_encoder') diff --git a/cosyvoice/bin/export_onnx.py b/cosyvoice/bin/export_onnx.py index fcb1594..dd9f009 100644 --- a/cosyvoice/bin/export_onnx.py +++ b/cosyvoice/bin/export_onnx.py @@ -62,135 +62,58 @@ def main(): model = CosyVoice(args.model_dir) except Exception: try: - # NOTE set use_flow_cache=True when export jit for cache inference - model = CosyVoice2(args.model_dir, use_flow_cache=True) + model = CosyVoice2(args.model_dir) except Exception: raise TypeError('no valid model_type!') - if not isinstance(model, CosyVoice2): - # 1. export flow decoder estimator - estimator = model.model.flow.decoder.estimator - estimator.eval() + # 1. export flow decoder estimator + estimator = model.model.flow.decoder.estimator + estimator.eval() - device = model.model.device - batch_size, seq_len = 2, 256 - out_channels = model.model.flow.decoder.estimator.out_channels - x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) - torch.onnx.export( - estimator, - (x, mask, mu, t, spks, cond), - '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), - export_params=True, - opset_version=18, - do_constant_folding=True, - input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], - output_names=['estimator_out'], - dynamic_axes={ - 'x': {2: 'seq_len'}, - 'mask': {2: 'seq_len'}, - 'mu': {2: 'seq_len'}, - 'cond': {2: 'seq_len'}, - 'estimator_out': {2: 'seq_len'}, - } - ) + device = model.model.device + batch_size, seq_len = 2, 256 + out_channels = model.model.flow.decoder.estimator.out_channels + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) + torch.onnx.export( + estimator, + (x, mask, mu, t, spks, cond), + '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], + output_names=['estimator_out'], + dynamic_axes={ + 'x': {2: 'seq_len'}, + 'mask': {2: 'seq_len'}, + 'mu': {2: 'seq_len'}, + 'cond': {2: 'seq_len'}, + 'estimator_out': {2: 'seq_len'}, + } + ) - # 2. test computation consistency - option = onnxruntime.SessionOptions() - option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - option.intra_op_num_threads = 1 - providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] - estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), - sess_options=option, providers=providers) + # 2. test computation consistency + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + sess_options=option, providers=providers) - for _ in tqdm(range(10)): - x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) - output_pytorch = estimator(x, mask, mu, t, spks, cond) - ort_inputs = { - 'x': x.cpu().numpy(), - 'mask': mask.cpu().numpy(), - 'mu': mu.cpu().numpy(), - 't': t.cpu().numpy(), - 'spks': spks.cpu().numpy(), - 'cond': cond.cpu().numpy() - } - output_onnx = estimator_onnx.run(None, ort_inputs)[0] - torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) - logging.info('successfully export estimator') - else: - # 1. export flow decoder estimator - estimator = model.model.flow.decoder.estimator - estimator.forward = estimator.forward_chunk - estimator.eval() - - device = model.model.device - batch_size, seq_len = 2, 256 - out_channels = model.model.flow.decoder.estimator.out_channels - x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) - cache = model.model.init_flow_cache()['decoder_cache'] - cache.pop('offset') - cache = {k: v[0] for k, v in cache.items()} - torch.onnx.export( - estimator, - (x, mask, mu, t, spks, cond, - cache['down_blocks_conv_cache'], - cache['down_blocks_kv_cache'], - cache['mid_blocks_conv_cache'], - cache['mid_blocks_kv_cache'], - cache['up_blocks_conv_cache'], - cache['up_blocks_kv_cache'], - cache['final_blocks_conv_cache']), - '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), - export_params=True, - opset_version=18, - do_constant_folding=True, - input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache', - 'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'], - output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out', - 'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'], - dynamic_axes={ - 'x': {2: 'seq_len'}, - 'mask': {2: 'seq_len'}, - 'mu': {2: 'seq_len'}, - 'cond': {2: 'seq_len'}, - 'down_blocks_kv_cache': {3: 'cache_in_len'}, - 'mid_blocks_kv_cache': {3: 'cache_in_len'}, - 'up_blocks_kv_cache': {3: 'cache_in_len'}, - 'estimator_out': {2: 'seq_len'}, - 'down_blocks_kv_cache_out': {3: 'cache_out_len'}, - 'mid_blocks_kv_cache_out': {3: 'cache_out_len'}, - 'up_blocks_kv_cache_out': {3: 'cache_out_len'}, - } - ) - - # 2. test computation consistency - option = onnxruntime.SessionOptions() - option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - option.intra_op_num_threads = 1 - providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] - estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), - sess_options=option, providers=providers) - - for iter in tqdm(range(10)): - x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) - cache = model.model.init_flow_cache()['decoder_cache'] - cache.pop('offset') - cache = {k: v[0] for k, v in cache.items()} - output_pytorch = estimator(x, mask, mu, t, spks, cond, **{k: v.clone() for k, v in cache.items()}) - ort_inputs = { - 'x': x.cpu().numpy(), - 'mask': mask.cpu().numpy(), - 'mu': mu.cpu().numpy(), - 't': t.cpu().numpy(), - 'spks': spks.cpu().numpy(), - 'cond': cond.cpu().numpy(), - } - output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}}) - if iter == 0: - # NOTE why can not pass first iteration check? - continue - for i, j in zip(output_pytorch, output_onnx): - torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4) - logging.info('successfully export estimator') + for _ in tqdm(range(10)): + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) + output_pytorch = estimator(x, mask, mu, t, spks, cond) + ort_inputs = { + 'x': x.cpu().numpy(), + 'mask': mask.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy() + } + output_onnx = estimator_onnx.run(None, ort_inputs)[0] + torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) + logging.info('successfully export estimator') if __name__ == "__main__": diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 3b9a7d5..b95a9e0 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -140,7 +140,7 @@ class CosyVoice: class CosyVoice2(CosyVoice): - def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False, trt_concurrent=1): + def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1): self.instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir self.fp16 = fp16 @@ -162,9 +162,9 @@ class CosyVoice2(CosyVoice): if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True): load_jit, load_trt, fp16 = False, False, False logging.warning('no cuda device, set load_jit/load_trt/fp16 to False') - self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache, trt_concurrent) + self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent) self.model.load('{}/llm.pt'.format(model_dir), - '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir), + '{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) if load_jit: self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 104c217..aa110b1 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -33,12 +33,14 @@ class CosyVoiceModel: llm: torch.nn.Module, flow: torch.nn.Module, hift: torch.nn.Module, - fp16: bool = False): + fp16: bool = False, + trt_concurrent: int = 1): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.llm = llm self.flow = flow self.hift = hift self.fp16 = fp16 + self.trt_concurrent = trt_concurrent if self.fp16 is True: self.llm.half() self.flow.half() @@ -85,23 +87,18 @@ class CosyVoiceModel: def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16): assert torch.cuda.is_available(), 'tensorrt only supports gpu!' - if not os.path.exists(flow_decoder_estimator_model): + if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16) - if os.path.getsize(flow_decoder_estimator_model) == 0: - raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model)) del self.flow.decoder.estimator import tensorrt as trt with open(flow_decoder_estimator_model, 'rb') as f: estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) - if isinstance(self, CosyVoice2Model): - self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent) - else: - self.flow.decoder.estimator = estimator_engine.create_execution_context() + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent) def get_trt_kwargs(self): min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)] - opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200)] + opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)] max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)] input_names = ["x", "mask", "mu", "cond"] return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} @@ -249,21 +246,21 @@ class CosyVoice2Model(CosyVoiceModel): flow: torch.nn.Module, hift: torch.nn.Module, fp16: bool = False, - use_flow_cache: bool = False, trt_concurrent: int = 1): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.llm = llm self.flow = flow + # NOTE default setting for jit/onnx export, you can set to False when using pytorch inference + self.flow.encoder.streaming = True + self.flow.decoder.estimator.streaming = True self.hift = hift self.fp16 = fp16 - self.use_flow_cache = use_flow_cache self.trt_concurrent = trt_concurrent if self.fp16 is True: self.llm.half() self.flow.half() - # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml + # NOTE must matching training static_chunk_size self.token_hop_len = 25 - self.flow_decoder_required_cache_size = 0 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio # hift cache self.mel_cache_len = 8 self.source_cache_len = int(self.mel_cache_len * 480) @@ -278,56 +275,24 @@ class CosyVoice2Model(CosyVoiceModel): # dict used to store session related variable self.tts_speech_token_dict = {} self.llm_end_dict = {} - self.flow_cache_dict = {} self.hift_cache_dict = {} self.trt_context_dict = {} - def init_flow_cache(self): - encoder_cache = {'offset': 0, - 'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device), - 'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device), - 'upsample_offset': 0, - 'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device), - 'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)} - decoder_cache = {'offset': 0, - 'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device), - 'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device), - 'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device), - 'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device), - 'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device), - 'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device), - 'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)} - if self.fp16 is True: - for cache in [encoder_cache, decoder_cache]: - for k, v in cache.items(): - if isinstance(v, torch.Tensor): - cache[k] = v.half() - cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache} - return cache - def load_jit(self, flow_encoder_model): flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder - def get_trt_kwargs(self): - min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (1, 4, 2, 0, 512, 2), (12, 4, 2, 0, 512, 2), (1, 4, 2, 0, 512, 2)] - opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200), (1, 4, 2, 100, 512, 2), (12, 4, 2, 100, 512, 2), (1, 4, 2, 100, 512, 2)] - max_shape = [(2, 80, 1500), (2, 1, 1500), (2, 80, 1500), (2, 80, 1500), (1, 4, 2, 200, 512, 2), (12, 4, 2, 200, 512, 2), (1, 4, 2, 200, 512, 2)] - input_names = ["x", "mask", "mu", "cond", 'down_blocks_kv_cache', 'mid_blocks_kv_cache', 'up_blocks_kv_cache'] - assert self.use_flow_cache is True, "get_trt_kwargs is set for flow cache mode. If you want to use trt with use_flow_cache=False, please set higher max_shape" - return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} - - def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0): + def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, finalize=False, speed=1.0): with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]: - tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device), - token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), - prompt_token=prompt_token.to(self.device), - prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), - prompt_feat=prompt_feat.to(self.device), - prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), - embedding=embedding.to(self.device), - cache=self.flow_cache_dict[uuid], - finalize=finalize) + tts_mel, _ = self.flow.inference(token=token.to(self.device), + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=embedding.to(self.device), + finalize=finalize) + tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:] # append hift cache if self.hift_cache_dict[uuid] is not None: hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source'] @@ -362,7 +327,6 @@ class CosyVoice2Model(CosyVoiceModel): with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None - self.flow_cache_dict[this_uuid] = self.init_flow_cache() self.trt_context_dict[this_uuid] = self.trt_context_pool.get() if source_speech_token.shape[1] == 0: p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) @@ -370,27 +334,23 @@ class CosyVoice2Model(CosyVoiceModel): p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid)) p.start() if stream is True: - assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM" - # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size - flow_prompt_speech_token = flow_prompt_speech_token[:, -int(self.flow_decoder_required_cache_size / self.flow.token_mel_ratio):] - prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size:] + token_offset = 0 + prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1]) while True: time.sleep(0.1) - if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len: - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) + this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len + if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len: + this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, embedding=flow_embedding, + token_offset=token_offset, uuid=this_uuid, finalize=False) - # NOTE in cache inference mode, we only use flow_prompt_speech_token/prompt_speech_feat in first chunk - flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device) - prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device) + token_offset += this_token_hop_len yield {'tts_speech': this_tts_speech.cpu()} - with self.lock: - self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:] - if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len: + if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len: break p.join() # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None @@ -399,18 +359,19 @@ class CosyVoice2Model(CosyVoiceModel): prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, embedding=flow_embedding, + token_offset=token_offset, uuid=this_uuid, finalize=True) yield {'tts_speech': this_tts_speech.cpu()} else: # deal with all tokens - assert self.use_flow_cache is False, "set use_flow_cache=False for nonstream inference" p.join() this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, embedding=flow_embedding, + token_offset=0, uuid=this_uuid, finalize=True, speed=speed) @@ -419,7 +380,6 @@ class CosyVoice2Model(CosyVoiceModel): self.tts_speech_token_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) - self.flow_cache_dict.pop(this_uuid) self.trt_context_pool.put(self.trt_context_dict[this_uuid]) self.trt_context_dict.pop(this_uuid) if torch.cuda.is_available(): diff --git a/cosyvoice/dataset/processor.py b/cosyvoice/dataset/processor.py index 08030d6..a94eb15 100644 --- a/cosyvoice/dataset/processor.py +++ b/cosyvoice/dataset/processor.py @@ -159,7 +159,7 @@ def truncate(data, truncate_length=24576, mode='train'): def compute_fbank(data, feat_extractor, - token_mel_ratio=2, + token_mel_ratio=0, mode='train'): """ Extract fbank @@ -176,12 +176,11 @@ def compute_fbank(data, assert 'text_token' in sample waveform = sample['speech'] feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1) - - # trim to align speech_token and speech_feat - token_len = min(feat.shape[0] // token_mel_ratio, sample["speech_token"].shape[0]) - feat = feat[:token_mel_ratio * token_len] - sample["speech_token"] = sample["speech_token"][:token_len] - + if token_mel_ratio != 0: + # trim to align speech_token and speech_feat + token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0])) + feat = feat[:token_mel_ratio * token_len] + sample["speech_token"] = sample["speech_token"][:token_len] sample['speech_feat'] = feat yield sample diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py index 4a89fb1..9e28c3f 100644 --- a/cosyvoice/flow/decoder.py +++ b/cosyvoice/flow/decoder.py @@ -11,16 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Optional, Dict, Any +from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F from einops import pack, rearrange, repeat -from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate from cosyvoice.utils.common import mask_to_bias from cosyvoice.utils.mask import add_optional_chunk_mask from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D -from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph +from matcha.models.components.transformer import BasicTransformerBlock class Transpose(torch.nn.Module): @@ -29,7 +28,7 @@ class Transpose(torch.nn.Module): self.dim0 = dim0 self.dim1 = dim1 - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]: + def forward(self, x: torch.Tensor) -> torch.Tensor: x = torch.transpose(x, self.dim0, self.dim1) return x @@ -57,15 +56,10 @@ class CausalConv1d(torch.nn.Conv1d): assert stride == 1 self.causal_padding = kernel_size - 1 - def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]: - if cache.size(2) == 0: - x = F.pad(x, (self.causal_padding, 0), value=0.0) - else: - assert cache.size(2) == self.causal_padding - x = torch.concat([cache, x], dim=2) - cache = x[:, :, -self.causal_padding:] + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, (self.causal_padding, 0), value=0.0) x = super(CausalConv1d, self).forward(x) - return x, cache + return x class CausalBlock1D(Block1D): @@ -79,11 +73,9 @@ class CausalBlock1D(Block1D): nn.Mish(), ) - def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]: - output, cache = self.block[0](x * mask, cache) - for i in range(1, len(self.block)): - output = self.block[i](output) - return output * mask, cache + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + output = self.block(x * mask) + return output * mask class CausalResnetBlock1D(ResnetBlock1D): @@ -92,303 +84,6 @@ class CausalResnetBlock1D(ResnetBlock1D): self.block1 = CausalBlock1D(dim, dim_out) self.block2 = CausalBlock1D(dim_out, dim_out) - def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor, - block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0) - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - h, block1_cache = self.block1(x, mask, block1_cache) - h += self.mlp(time_emb).unsqueeze(-1) - h, block2_cache = self.block2(h, mask, block2_cache) - output = h + self.res_conv(x * mask) - return output, block1_cache, block2_cache - - -class CausalAttnProcessor2_0(AttnProcessor2_0): - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - super(CausalAttnProcessor2_0, self).__init__() - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - temb: Optional[torch.FloatTensor] = None, - cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - *args, - **kwargs, - ) -> Tuple[torch.FloatTensor, torch.Tensor]: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \ - `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - - residual = hidden_states - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - # NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key_cache = attn.to_k(encoder_hidden_states) - value_cache = attn.to_v(encoder_hidden_states) - # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2) - if cache.size(0) != 0: - key = torch.concat([cache[:, :, :, 0], key_cache], dim=1) - value = torch.concat([cache[:, :, :, 1], value_cache], dim=1) - else: - key, value = key_cache, value_cache - cache = torch.stack([key_cache, value_cache], dim=3) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states, cache - - -@maybe_allow_in_graph -class CausalAttention(Attention): - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - cross_attention_norm: Optional[str] = None, - cross_attention_norm_num_groups: int = 32, - qk_norm: Optional[str] = None, - added_kv_proj_dim: Optional[int] = None, - norm_num_groups: Optional[int] = None, - spatial_norm_dim: Optional[int] = None, - out_bias: bool = True, - scale_qk: bool = True, - only_cross_attention: bool = False, - eps: float = 1e-5, - rescale_output_factor: float = 1.0, - residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, - processor: Optional["AttnProcessor2_0"] = None, - out_dim: int = None, - ): - super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, - cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups, - spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection, - _from_deprecated_attn_block, processor, out_dim) - processor = CausalAttnProcessor2_0() - self.set_processor(processor) - - def forward( - self, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - **cross_attention_kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: - r""" - The forward method of the `Attention` class. - - Args: - hidden_states (`torch.Tensor`): - The hidden states of the query. - encoder_hidden_states (`torch.Tensor`, *optional*): - The hidden states of the encoder. - attention_mask (`torch.Tensor`, *optional*): - The attention mask to use. If `None`, no mask is applied. - **cross_attention_kwargs: - Additional keyword arguments to pass along to the cross attention. - - Returns: - `torch.Tensor`: The output of the attention layer. - """ - # The `Attention` class can call different attention processors / attention functions - # here we simply pass along all tensors to the selected processor class - # For standard processors that are defined here, `**cross_attention_kwargs` is empty - - attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters] - if len(unused_kwargs) > 0: - logger.warning( - f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." - ) - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} - - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cache=cache, - **cross_attention_kwargs, - ) - - -@maybe_allow_in_graph -class CausalBasicTransformerBlock(BasicTransformerBlock): - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", - final_dropout: bool = False, - ): - super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout, - cross_attention_dim, activation_fn, num_embeds_ada_norm, - attention_bias, only_cross_attention, double_self_attention, - upcast_attention, norm_elementwise_affine, norm_type, final_dropout) - self.attn1 = CausalAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Notice that normalization is always applied before the real computation in the following blocks. - # 1. Self-Attention - if self.use_ada_layer_norm: - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.use_ada_layer_norm_zero: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - else: - norm_hidden_states = self.norm1(hidden_states) - - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - - attn_output, cache = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, - cache=cache, - **cross_attention_kwargs, - ) - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = attn_output + hidden_states - - # 2. Cross-Attention - if self.attn2 is not None: - norm_hidden_states = ( - self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) - ) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - # 3. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self.use_ada_layer_norm_zero: - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: - raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \ - {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.") - - num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size - ff_output = torch.cat( - [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], - dim=self._chunk_dim, - ) - else: - ff_output = self.ff(norm_hidden_states) - - if self.use_ada_layer_norm_zero: - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = ff_output + hidden_states - - return hidden_states, cache - class ConditionalDecoder(nn.Module): def __init__( @@ -640,7 +335,7 @@ class CausalConditionalDecoder(ConditionalDecoder): resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) transformer_blocks = nn.ModuleList( [ - CausalBasicTransformerBlock( + BasicTransformerBlock( dim=output_channel, num_attention_heads=num_heads, attention_head_dim=attention_head_dim, @@ -662,7 +357,7 @@ class CausalConditionalDecoder(ConditionalDecoder): transformer_blocks = nn.ModuleList( [ - CausalBasicTransformerBlock( + BasicTransformerBlock( dim=output_channel, num_attention_heads=num_heads, attention_head_dim=attention_head_dim, @@ -687,7 +382,7 @@ class CausalConditionalDecoder(ConditionalDecoder): ) transformer_blocks = nn.ModuleList( [ - CausalBasicTransformerBlock( + BasicTransformerBlock( dim=output_channel, num_attention_heads=num_heads, attention_head_dim=attention_head_dim, @@ -724,6 +419,9 @@ class CausalConditionalDecoder(ConditionalDecoder): Returns: _type_: _description_ """ + if hasattr(self, 'streaming'): + assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode' + streaming = self.streaming t = self.time_embeddings(t).to(t.dtype) t = self.time_mlp(t) @@ -740,36 +438,36 @@ class CausalConditionalDecoder(ConditionalDecoder): masks = [mask] for resnet, transformer_blocks, downsample in self.down_blocks: mask_down = masks[-1] - x, _, _ = resnet(x, mask_down, t) + x = resnet(x, mask_down, t) x = rearrange(x, "b c t -> b t c").contiguous() if streaming is True: - attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks) + attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1) else: attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) attn_mask = mask_to_bias(attn_mask, x.dtype) for transformer_block in transformer_blocks: - x, _ = transformer_block( + x = transformer_block( hidden_states=x, attention_mask=attn_mask, timestep=t, ) x = rearrange(x, "b t c -> b c t").contiguous() hiddens.append(x) # Save hidden states for skip connections - x, _ = downsample(x * mask_down) + x = downsample(x * mask_down) masks.append(mask_down[:, :, ::2]) masks = masks[:-1] mask_mid = masks[-1] for resnet, transformer_blocks in self.mid_blocks: - x, _, _ = resnet(x, mask_mid, t) + x = resnet(x, mask_mid, t) x = rearrange(x, "b c t -> b t c").contiguous() if streaming is True: - attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks) + attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1) else: attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) attn_mask = mask_to_bias(attn_mask, x.dtype) for transformer_block in transformer_blocks: - x, _ = transformer_block( + x = transformer_block( hidden_states=x, attention_mask=attn_mask, timestep=t, @@ -780,124 +478,21 @@ class CausalConditionalDecoder(ConditionalDecoder): mask_up = masks.pop() skip = hiddens.pop() x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] - x, _, _ = resnet(x, mask_up, t) + x = resnet(x, mask_up, t) x = rearrange(x, "b c t -> b t c").contiguous() if streaming is True: - attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks) + attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1) else: attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) attn_mask = mask_to_bias(attn_mask, x.dtype) for transformer_block in transformer_blocks: - x, _ = transformer_block( + x = transformer_block( hidden_states=x, attention_mask=attn_mask, timestep=t, ) x = rearrange(x, "b t c -> b c t").contiguous() - x, _ = upsample(x * mask_up) - x, _ = self.final_block(x, mask_up) + x = upsample(x * mask_up) + x = self.final_block(x, mask_up) output = self.final_proj(x * mask_up) return output * mask - - @torch.inference_mode() - def forward_chunk(self, x, mask, mu, t, spks=None, cond=None, - down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0), - mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0), - up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0), - final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0) - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Forward pass of the UNet1DConditional model. - - Args: - x (torch.Tensor): shape (batch_size, in_channels, time) - mask (_type_): shape (batch_size, 1, time) - t (_type_): shape (batch_size) - spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. - cond (_type_, optional): placeholder for future use. Defaults to None. - - Raises: - ValueError: _description_ - ValueError: _description_ - - Returns: - _type_: _description_ - """ - - t = self.time_embeddings(t).to(t.dtype) - t = self.time_mlp(t) - - x = pack([x, mu], "b * t")[0] - - if spks is not None: - spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) - x = pack([x, spks], "b * t")[0] - if cond is not None: - x = pack([x, cond], "b * t")[0] - - hiddens = [] - masks = [mask] - - down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device) - mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device) - up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device) - for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks): - mask_down = masks[-1] - x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \ - resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576]) - x = rearrange(x, "b c t -> b t c").contiguous() - attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool() - attn_mask = mask_to_bias(attn_mask, x.dtype) - for i, transformer_block in enumerate(transformer_blocks): - x, down_blocks_kv_cache_new[index, i] = transformer_block( - hidden_states=x, - attention_mask=attn_mask, - timestep=t, - cache=down_blocks_kv_cache[index, i], - ) - x = rearrange(x, "b t c -> b c t").contiguous() - hiddens.append(x) # Save hidden states for skip connections - x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:]) - masks.append(mask_down[:, :, ::2]) - masks = masks[:-1] - mask_mid = masks[-1] - - for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks): - x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \ - resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:]) - x = rearrange(x, "b c t -> b t c").contiguous() - attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool() - attn_mask = mask_to_bias(attn_mask, x.dtype) - for i, transformer_block in enumerate(transformer_blocks): - x, mid_blocks_kv_cache_new[index, i] = transformer_block( - hidden_states=x, - attention_mask=attn_mask, - timestep=t, - cache=mid_blocks_kv_cache[index, i] - ) - x = rearrange(x, "b t c -> b c t").contiguous() - - for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks): - mask_up = masks.pop() - skip = hiddens.pop() - x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] - x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \ - resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768]) - x = rearrange(x, "b c t -> b t c").contiguous() - attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool() - attn_mask = mask_to_bias(attn_mask, x.dtype) - for i, transformer_block in enumerate(transformer_blocks): - x, up_blocks_kv_cache_new[index, i] = transformer_block( - hidden_states=x, - attention_mask=attn_mask, - timestep=t, - cache=up_blocks_kv_cache[index, i] - ) - x = rearrange(x, "b t c -> b c t").contiguous() - x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:]) - x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache) - output = self.final_proj(x * mask_up) - return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \ - up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index e1cf429..d9e832b 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -241,7 +241,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module): prompt_feat, prompt_feat_len, embedding, - cache, finalize): assert token.shape[0] == 1 # xvec projection @@ -255,16 +254,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module): # text encode if finalize is True: - h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache']) + h, h_lengths = self.encoder(token, token_len) else: token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:] - h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache']) - cache['encoder_cache']['offset'] = encoder_cache[0] - cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1] - cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2] - cache['encoder_cache']['upsample_offset'] = encoder_cache[3] - cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4] - cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5] + h, h_lengths = self.encoder(token, token_len, context=context) mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1] h = self.encoder_proj(h) @@ -274,14 +267,13 @@ class CausalMaskedDiffWithXvec(torch.nn.Module): conds = conds.transpose(1, 2) mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) - feat, cache['decoder_cache'] = self.decoder( + feat, _ = self.decoder( mu=h.transpose(1, 2).contiguous(), mask=mask.unsqueeze(1), spks=embedding, cond=conds, n_timesteps=10, - cache=cache['decoder_cache'] ) feat = feat[:, :, mel_len1:] assert feat.shape[2] == mel_len2 - return feat.float(), cache + return feat.float(), None diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 47e6961..735889f 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -126,21 +126,26 @@ class ConditionalCFM(BASECFM): if isinstance(self.estimator, torch.nn.Module): return self.estimator(x, mask, mu, t, spks, cond) else: - with self.lock: - self.estimator.set_input_shape('x', (2, 80, x.size(2))) - self.estimator.set_input_shape('mask', (2, 1, x.size(2))) - self.estimator.set_input_shape('mu', (2, 80, x.size(2))) - self.estimator.set_input_shape('t', (2,)) - self.estimator.set_input_shape('spks', (2, 80)) - self.estimator.set_input_shape('cond', (2, 80, x.size(2))) - # run trt engine - assert self.estimator.execute_v2([x.contiguous().data_ptr(), - mask.contiguous().data_ptr(), - mu.contiguous().data_ptr(), - t.contiguous().data_ptr(), - spks.contiguous().data_ptr(), - cond.contiguous().data_ptr(), - x.data_ptr()]) is True + estimator, trt_engine = self.estimator.acquire_estimator() + estimator.set_input_shape('x', (2, 80, x.size(2))) + estimator.set_input_shape('mask', (2, 1, x.size(2))) + estimator.set_input_shape('mu', (2, 80, x.size(2))) + estimator.set_input_shape('t', (2,)) + estimator.set_input_shape('spks', (2, 80)) + estimator.set_input_shape('cond', (2, 80, x.size(2))) + data_ptrs = [x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()] + for i, j in enumerate(data_ptrs): + estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.estimator.release_estimator(estimator) return x def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False): @@ -191,7 +196,7 @@ class CausalConditionalCFM(ConditionalCFM): self.rand_noise = torch.randn([1, 80, 50 * 300]) @torch.inference_mode() - def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}): + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): """Forward diffusion Args: @@ -210,136 +215,9 @@ class CausalConditionalCFM(ConditionalCFM): shape: (batch_size, n_feats, mel_timesteps) """ - offset = cache.pop('offset') - z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature - z = z[:, :, offset:] - offset += mu.size(2) + z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature # fix prompt and overlap part mu and z t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) if self.t_scheduler == 'cosine': t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) - mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache) - cache['offset'] = offset - return mel, cache - - def solve_euler(self, x, t_span, mu, mask, spks, cond, cache): - """ - Fixed euler solver for ODEs. - Args: - x (torch.Tensor): random noise - t_span (torch.Tensor): n_timesteps interpolated - shape: (n_timesteps + 1,) - mu (torch.Tensor): output of encoder - shape: (batch_size, n_feats, mel_timesteps) - mask (torch.Tensor): output_mask - shape: (batch_size, 1, mel_timesteps) - spks (torch.Tensor, optional): speaker ids. Defaults to None. - shape: (batch_size, spk_emb_dim) - cond: Not used but kept for future purposes - """ - t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] - t = t.unsqueeze(dim=0) - - # I am storing this because I can later plot it by putting a debugger here and saving it to a file - # Or in future might add like a return_all_steps flag - sol = [] - - # Do not use concat, it may cause memory format changed and trt infer with wrong results! - x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) - mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype) - mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) - t_in = torch.zeros([2], device=x.device, dtype=x.dtype) - spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype) - cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) - flow_cache_size = cache['down_blocks_kv_cache'].shape[4] - for step in range(1, len(t_span)): - # Classifier-Free Guidance inference introduced in VoiceBox - x_in[:] = x - mask_in[:] = mask - mu_in[0] = mu - t_in[:] = t.unsqueeze(0) - spks_in[0] = spks - cond_in[0] = cond - cache_step = {k: v[step - 1] for k, v in cache.items()} - dphi_dt, cache_step = self.forward_estimator( - x_in, mask_in, - mu_in, t_in, - spks_in, - cond_in, - cache_step - ) - # NOTE if smaller than flow_cache_size, means last chunk, no need to cache - if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size: - cache['down_blocks_conv_cache'][step - 1] = cache_step[0] - cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:, :, :, -flow_cache_size:] - cache['mid_blocks_conv_cache'][step - 1] = cache_step[2] - cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:, :, :, -flow_cache_size:] - cache['up_blocks_conv_cache'][step - 1] = cache_step[4] - cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:, :, :, -flow_cache_size:] - cache['final_blocks_conv_cache'][step - 1] = cache_step[6] - dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) - dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) - x = x + dt * dphi_dt - t = t + dt - sol.append(x) - if step < len(t_span) - 1: - dt = t_span[step + 1] - t - return sol[-1].float(), cache - - def forward_estimator(self, x, mask, mu, t, spks, cond, cache): - if isinstance(self.estimator, torch.nn.Module): - x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache) - cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7) - else: - estimator, trt_engine = self.estimator.acquire_estimator() - estimator.set_input_shape('x', (2, 80, x.size(2))) - estimator.set_input_shape('mask', (2, 1, x.size(2))) - estimator.set_input_shape('mu', (2, 80, x.size(2))) - estimator.set_input_shape('t', (2,)) - estimator.set_input_shape('spks', (2, 80)) - estimator.set_input_shape('cond', (2, 80, x.size(2))) - estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape) - estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape) - estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape) - estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape) - estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape) - estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape) - estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape) - down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x) - mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x) - up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x) - data_ptrs = [x.contiguous().data_ptr(), - mask.contiguous().data_ptr(), - mu.contiguous().data_ptr(), - t.contiguous().data_ptr(), - spks.contiguous().data_ptr(), - cond.contiguous().data_ptr(), - cache['down_blocks_conv_cache'].contiguous().data_ptr(), - cache['down_blocks_kv_cache'].contiguous().data_ptr(), - cache['mid_blocks_conv_cache'].contiguous().data_ptr(), - cache['mid_blocks_kv_cache'].contiguous().data_ptr(), - cache['up_blocks_conv_cache'].contiguous().data_ptr(), - cache['up_blocks_kv_cache'].contiguous().data_ptr(), - cache['final_blocks_conv_cache'].contiguous().data_ptr(), - x.data_ptr(), - cache['down_blocks_conv_cache'].data_ptr(), - down_blocks_kv_cache_out.data_ptr(), - cache['mid_blocks_conv_cache'].data_ptr(), - mid_blocks_kv_cache_out.data_ptr(), - cache['up_blocks_conv_cache'].data_ptr(), - up_blocks_kv_cache_out.data_ptr(), - cache['final_blocks_conv_cache'].data_ptr()] - for i, j in enumerate(data_ptrs): - estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) - # run trt engine - assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True - torch.cuda.current_stream().synchronize() - self.estimator.release_estimator(estimator) - cache = (cache['down_blocks_conv_cache'], - down_blocks_kv_cache_out, - cache['mid_blocks_conv_cache'], - mid_blocks_kv_cache_out, - cache['up_blocks_conv_cache'], - up_blocks_kv_cache_out, - cache['final_blocks_conv_cache']) - return x, cache + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py index 50d7f99..326a1a7 100644 --- a/cosyvoice/hifigan/generator.py +++ b/cosyvoice/hifigan/generator.py @@ -223,6 +223,172 @@ class SourceModuleHnNSF(torch.nn.Module): return sine_merge, noise, uv +class SineGen2(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, upsample_scale, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False): + super(SineGen2, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + self.upsample_scale = upsample_scale + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + def _f02sine(self, f0_values): + """ f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. The interger part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2), + scale_factor=1 / self.upsample_scale, + mode="linear").transpose(1, 2) + + phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi + phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale, + scale_factor=self.upsample_scale, mode="linear").transpose(1, 2) + sines = torch.sin(phase) + else: + # If necessary, make sure that the first time step of every + # voiced segments is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantanouse phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segments + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment. + i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + # fundamental component + fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) + + # generate sine waveforms + sine_waves = self._f02sine(fn) * self.sine_amp + + # generate uv signal + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF2(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF2, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + with torch.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + class HiFTGenerator(nn.Module): """ HiFTNet Generator: Neural Source Filter + ISTFTNet @@ -259,7 +425,9 @@ class HiFTGenerator(nn.Module): self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) - self.m_source = SourceModuleHnNSF( + # NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation + this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2 + self.m_source = this_SourceModuleHnNSF( sampling_rate=sampling_rate, upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"], harmonic_num=nb_harmonics, diff --git a/cosyvoice/transformer/upsample_encoder.py b/cosyvoice/transformer/upsample_encoder.py index 0d98406..e17b188 100644 --- a/cosyvoice/transformer/upsample_encoder.py +++ b/cosyvoice/transformer/upsample_encoder.py @@ -56,16 +56,11 @@ class Upsample1D(nn.Module): # In this mode, first repeat interpolate, than conv with stride=1 self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0) - def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest") - if conv_cache.size(2) == 0: - outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0) - else: - assert conv_cache.size(2) == self.stride * 2 - outputs = torch.concat([conv_cache, outputs], dim=2) - conv_cache_new = outputs[:, :, -self.stride * 2:] + outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0) outputs = self.conv(outputs) - return outputs, input_lengths * self.stride, conv_cache_new + return outputs, input_lengths * self.stride class PreLookaheadLayer(nn.Module): @@ -83,7 +78,7 @@ class PreLookaheadLayer(nn.Module): kernel_size=3, stride=1, padding=0, ) - def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0), conv2_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor: """ inputs: (batch_size, seq_len, channels) """ @@ -93,22 +88,18 @@ class PreLookaheadLayer(nn.Module): if context.size(2) == 0: outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0) else: + assert self.training is False, 'you have passed context, make sure that you are running inference mode' assert context.size(2) == self.pre_lookahead_len outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0) outputs = F.leaky_relu(self.conv1(outputs)) # outputs - if conv2_cache.size(2) == 0: - outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0) - else: - assert conv2_cache.size(2) == self.conv2.kernel_size[0] - 1 - outputs = torch.concat([conv2_cache, outputs], dim=2) - conv2_cache_new = outputs[:, :, -(self.conv2.kernel_size[0] - 1):] + outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0) outputs = self.conv2(outputs) outputs = outputs.transpose(1, 2).contiguous() # residual connection outputs = outputs + inputs - return outputs, conv2_cache_new + return outputs class UpsampleConformerEncoder(torch.nn.Module): @@ -253,6 +244,7 @@ class UpsampleConformerEncoder(torch.nn.Module): self, xs: torch.Tensor, xs_lens: torch.Tensor, + context: torch.Tensor = torch.zeros(0, 0, 0), decoding_chunk_size: int = 0, num_decoding_left_chunks: int = -1, streaming: bool = False, @@ -280,20 +272,27 @@ class UpsampleConformerEncoder(torch.nn.Module): checkpointing API because `__call__` attaches all the hooks of the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 """ + if hasattr(self, 'streaming'): + assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode' + streaming = self.streaming T = xs.size(1) masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) if self.global_cmvn is not None: xs = self.global_cmvn(xs) xs, pos_emb, masks = self.embed(xs, masks) + if context.size(1) != 0: + assert self.training is False, 'you have passed context, make sure that you are running inference mode' + context_masks = torch.ones(1, 1, context.size(1)).to(masks) + context, _, _ = self.embed(context, context_masks, offset=xs.size(1)) mask_pad = masks # (B, 1, T/subsample_rate) chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1) # lookahead + conformer encoder - xs, _ = self.pre_lookahead_layer(xs) + xs = self.pre_lookahead_layer(xs, context=context) xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) # upsample + conformer encoder xs = xs.transpose(1, 2).contiguous() - xs, xs_lens, _ = self.up_layer(xs, xs_lens) + xs, xs_lens = self.up_layer(xs, xs_lens) xs = xs.transpose(1, 2).contiguous() T = xs.size(1) masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) @@ -322,100 +321,3 @@ class UpsampleConformerEncoder(torch.nn.Module): for layer in self.up_encoders: xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) return xs - - @torch.jit.export - def forward_chunk( - self, - xs: torch.Tensor, - xs_lens: torch.Tensor, - offset: int = 0, - context: torch.Tensor = torch.zeros(0, 0, 0), - pre_lookahead_layer_conv2_cache: torch.Tensor = torch.zeros(0, 0, 0), - encoders_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0), - upsample_offset: int = 0, - upsample_conv_cache: torch.Tensor = torch.zeros(0, 0, 0), - upsample_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0) - ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[int, torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]]: - """Embed positions in tensor. - - Args: - xs: padded input tensor (B, T, D) - xs_lens: input length (B) - decoding_chunk_size: decoding chunk size for dynamic chunk - 0: default for training, use random dynamic chunk. - <0: for decoding, use full chunk. - >0: for decoding, use fixed chunk size as set. - num_decoding_left_chunks: number of left chunks, this is for decoding, - the chunk size is decoding_chunk_size. - >=0: use num_decoding_left_chunks - <0: use all left chunks - Returns: - encoder output tensor xs, and subsampled masks - xs: padded output tensor (B, T' ~= T/subsample_rate, D) - masks: torch.Tensor batch padding mask after subsample - (B, 1, T' ~= T/subsample_rate) - NOTE(xcsong): - We pass the `__call__` method of the modules instead of `forward` to the - checkpointing API because `__call__` attaches all the hooks of the module. - https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 - """ - assert xs.size(0) == 1 - # tmp_masks is just for interface compatibility - tmp_masks = torch.ones(1, - xs.size(1), - device=xs.device, - dtype=torch.bool) - tmp_masks = tmp_masks.unsqueeze(1) - if self.global_cmvn is not None: - xs = self.global_cmvn(xs) - # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) - xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) - offset += xs.size(1) - tmp_masks = torch.ones(1, - context.size(1), - device=context.device, - dtype=torch.bool) - tmp_masks = tmp_masks.unsqueeze(1) - if context.size(1) != 0: - context, _, _ = self.embed(context, tmp_masks, offset) - - # lookahead + conformer encoder - xs, pre_lookahead_layer_conv2_cache = self.pre_lookahead_layer(xs, context, pre_lookahead_layer_conv2_cache) - # NOTE in cache mode we do not need to call add_optional_chunk_mask - chunk_masks = torch.ones((1, xs.size(1), offset), dtype=torch.bool, device=xs.device) - mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device) - encoders_kv_cache_list = [] - for index, layer in enumerate(self.encoders): - xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index]) - encoders_kv_cache_list.append(encoders_kv_cache_new) - encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0) - - # upsample - xs = xs.transpose(1, 2).contiguous() - xs, xs_lens, upsample_conv_cache = self.up_layer(xs, xs_lens, upsample_conv_cache) - xs = xs.transpose(1, 2).contiguous() - - # tmp_masks is just for interface compatibility - tmp_masks = torch.ones(1, - xs.size(1), - device=xs.device, - dtype=torch.bool) - tmp_masks = tmp_masks.unsqueeze(1) - xs, pos_emb, masks = self.up_embed(xs, tmp_masks, upsample_offset) - upsample_offset += xs.size(1) - - # conformer encoder - chunk_masks = torch.ones((1, xs.size(1), upsample_offset), dtype=torch.bool, device=xs.device) - mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device) - upsample_kv_cache_list = [] - for index, layer in enumerate(self.up_encoders): - xs, chunk_masks, upsample_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, upsample_kv_cache[index]) - upsample_kv_cache_list.append(upsample_kv_cache_new) - upsample_kv_cache = torch.stack(upsample_kv_cache_list, dim=0) - - if self.normalize_before: - xs = self.after_norm(xs) - # Here we assume the mask is not changed in encoder layers, so just - # return the masks before encoder layers, and the masks will be used - # for cross attention with decoder later - return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache) diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py index 80eafaf..ae860c9 100644 --- a/cosyvoice/utils/file_utils.py +++ b/cosyvoice/utils/file_utils.py @@ -56,7 +56,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): network = builder.create_network(network_flags) parser = trt.OnnxParser(network, logger) config = builder.create_builder_config() - config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 31) # 1GB if fp16: config.set_flag(trt.BuilderFlag.FP16) profile = builder.create_optimization_profile() diff --git a/cosyvoice/utils/mask.py b/cosyvoice/utils/mask.py index 35dcd69..c966cc9 100644 --- a/cosyvoice/utils/mask.py +++ b/cosyvoice/utils/mask.py @@ -86,7 +86,7 @@ def subsequent_mask( return mask -def subsequent_chunk_mask( +def subsequent_chunk_mask_deprecated( size: int, chunk_size: int, num_left_chunks: int = -1, @@ -124,6 +124,40 @@ def subsequent_chunk_mask( return ret +def subsequent_chunk_mask( + size: int, + chunk_size: int, + num_left_chunks: int = -1, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size) with chunk size, + this is for streaming encoder + + Args: + size (int): size of mask + chunk_size (int): size of chunk + num_left_chunks (int): number of left chunks + <0: use full chunk + >=0: use num_left_chunks + device (torch.device): "cpu" or "cuda" or torch.Tensor.device + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_chunk_mask(4, 2) + [[1, 1, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]] + """ + # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks + pos_idx = torch.arange(size, device=device) + block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size + ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1) + return ret + + def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor, use_dynamic_chunk: bool, @@ -196,9 +230,6 @@ def add_optional_chunk_mask(xs: torch.Tensor, else: chunk_masks = masks assert chunk_masks.dtype == torch.bool - if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0: - print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!') - chunk_masks[chunk_masks.sum(dim=-1) == 0] = True return chunk_masks diff --git a/test1.py b/test1.py deleted file mode 100644 index a1243e4..0000000 --- a/test1.py +++ /dev/null @@ -1,37 +0,0 @@ -import sys -sys.path.append('third_party/Matcha-TTS') -from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 -from cosyvoice.utils.file_utils import load_wav -import torchaudio # type: ignore - -cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False) - -# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference -# zero_shot usage -prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000) -for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)): - torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) - -# save zero_shot spk for future usage -assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', prompt_speech_16k, 'my_zero_shot_spk') is True -for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk', stream=False)): - torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) -cosyvoice.save_spkinfo() - -# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248 -for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)): - torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) - -# instruct usage -for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)): - torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) - -# bistream usage, you can use generator as input, this is useful when using text llm model as input -# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length -def text_generator(): - yield '收到好友从远方寄来的生日礼物,' - yield '那份意外的惊喜与深深的祝福' - yield '让我心中充满了甜蜜的快乐,' - yield '笑容如花儿般绽放。' -for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)): - torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) \ No newline at end of file