fix cache bug

lyuxiang.lx
2025-01-24 11:07:26 +08:00
parent 1c062ab381
commit aea75207dd
2 changed files with 20 additions and 22 deletions

@@ -396,6 +396,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
         encoders_kv_cache_list = []
         for index, layer in enumerate(self.encoders):
             xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
+            encoders_kv_cache_list.append(encoders_kv_cache_new)
         encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)

         # upsample
@@ -426,4 +427,4 @@ class UpsampleConformerEncoder(torch.nn.Module):
         # Here we assume the mask is not changed in encoder layers, so just
         # return the masks before encoder layers, and the masks will be used
         # for cross attention with decoder later
-        return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache_new, upsample_offset, upsample_conv_cache, upsample_kv_cache_new)
+        return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache)
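What the change fixes: inside the loop, every encoder layer consumes its slice of encoders_kv_cache and produces an updated encoders_kv_cache_new, but nothing was appended to encoders_kv_cache_list, and the return statement handed back only the last layer's cache (encoders_kv_cache_new and upsample_kv_cache_new). With the fix, each layer's updated cache is collected, stacked along dim 0, and the stacked tensors (encoders_kv_cache, upsample_kv_cache) are returned, so the next chunk can again index the cache per layer. A minimal sketch of the intended pattern, with layer count and cache shapes made up for illustration:

import torch

num_layers, cache_len, dim = 4, 16, 8
# incoming cache: one slice per encoder layer, stacked along dim 0
encoders_kv_cache = torch.zeros(num_layers, cache_len, dim)

encoders_kv_cache_list = []
for index in range(num_layers):
    # stand-in for: xs, chunk_masks, encoders_kv_cache_new, _ = layer(..., encoders_kv_cache[index])
    encoders_kv_cache_new = encoders_kv_cache[index] + 1.0
    encoders_kv_cache_list.append(encoders_kv_cache_new)  # the line this commit adds

# return the stacked per-layer caches, not just the last layer's encoders_kv_cache_new
encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)
assert encoders_kv_cache.shape == (num_layers, cache_len, dim)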


@@ -56,7 +56,7 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
         input_size: 512
         use_cnn_module: False
         macaron_style: False
-        use_dynamic_chunk: True
+        static_chunk_size: !ref <token_frame_rate> # try making UpsampleConformerEncoder static as well
     decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
         in_channels: 240
         n_spks: 1
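The flow encoder config also moves from use_dynamic_chunk: True to a fixed static_chunk_size tied to <token_frame_rate>; per the inline comment, the idea is to try running UpsampleConformerEncoder with a static chunk as well, so the chunking seen at inference matches training and fits the fixed-size cache handling in the encoder change above. Roughly, a static chunk mask lets each frame attend to its own chunk plus all earlier chunks; a plain-torch illustration (hypothetical helper, not CosyVoice's own mask utility):

import torch

def static_chunk_mask(size: int, chunk_size: int) -> torch.Tensor:
    # mask[i, j] is True when key frame j lies in the same chunk as
    # query frame i or in an earlier chunk
    chunk_idx = torch.arange(size) // chunk_size
    return chunk_idx.unsqueeze(0) <= chunk_idx.unsqueeze(1)

mask = static_chunk_mask(size=8, chunk_size=4)
# frames 0-3 attend within chunk 0 only; frames 4-7 attend to chunks 0 and 1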
@@ -154,12 +154,9 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     center: False
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
-# pitch_extractor: !name:torchaudio.functional.compute_kaldi_pitch # TODO need to replace it
-#     sample_rate: !ref <sample_rate>
-#     frame_length: 46.4 # match feat_extractor win_size/sampling_rate
-#     frame_shift: 11.6 # match feat_extractor hop_size/sampling_rate
-# compute_f0: !name:cosyvoice.dataset.processor.compute_f0
-#     pitch_extractor: !ref <pitch_extractor>
+compute_f0: !name:cosyvoice.dataset.processor.compute_f0
+    sample_rate: !ref <sample_rate>
+    hop_size: 480
 parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
     normalize: True
 shuffle: !name:cosyvoice.dataset.processor.shuffle
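The commented-out kaldi-pitch extractor is dropped in favour of the project's own compute_f0 processor. If the F0 hop_size of 480 matches the mel feat_extractor's hop at a 24 kHz sample rate, the two feature streams come out frame-aligned at 50 frames per second; the arithmetic, with the 24 kHz rate taken as an assumption:

# assumed values; only hop_size: 480 is stated in this diff
sample_rate = 24000
hop_size = 480
frames_per_second = sample_rate // hop_size  # 50, same grid as the mel features if they share the hop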
@@ -186,20 +183,20 @@ data_pipeline: [
     !ref <batch>,
     !ref <padding>,
 ]
-# data_pipeline_gan: [
-#     !ref <parquet_opener>,
-#     !ref <tokenize>,
-#     !ref <filter>,
-#     !ref <resample>,
-#     !ref <truncate>,
-#     !ref <compute_fbank>,
-#     !ref <compute_f0>,
-#     !ref <parse_embedding>,
-#     !ref <shuffle>,
-#     !ref <sort>,
-#     !ref <batch>,
-#     !ref <padding>,
-# ]
+data_pipeline_gan: [
+    !ref <parquet_opener>,
+    !ref <tokenize>,
+    !ref <filter>,
+    !ref <resample>,
+    !ref <truncate>,
+    !ref <compute_fbank>,
+    !ref <compute_f0>,
+    !ref <parse_embedding>,
+    !ref <shuffle>,
+    !ref <sort>,
+    !ref <batch>,
+    !ref <padding>,
+]

 # llm flow train conf
 train_conf:
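Re-enabling data_pipeline_gan restores the data path that also computes fbank and F0, presumably for the GAN vocoder training stage; like data_pipeline, it is an ordered list of processors that are applied one after another to a stream of samples. A generic sketch of that chaining pattern, with stand-in processors rather than CosyVoice's real ones:

from typing import Callable, Iterable, List

def apply_pipeline(samples: Iterable[dict], processors: List[Callable]) -> Iterable[dict]:
    # each processor takes an iterable of samples and yields processed samples,
    # so the list composes left to right, mirroring the YAML order
    data = samples
    for proc in processors:
        data = proc(data)
    return data

def add_field(name, value):
    # stand-in for entries such as compute_fbank or compute_f0
    def proc(data):
        for sample in data:
            sample[name] = value
            yield sample
    return proc

pipeline = [add_field("speech_feat", [0.1, 0.2]), add_field("pitch_feat", [100.0])]
for sample in apply_pipeline([{"utt": "demo"}], pipeline):
    print(sample)  # {'utt': 'demo', 'speech_feat': [0.1, 0.2], 'pitch_feat': [100.0]}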