From 65ad448714c60fc6d4133ea3a0439e2ed5320b43 Mon Sep 17 00:00:00 2001
From: burkliu
Date: Thu, 24 Apr 2025 17:14:49 +0800
Subject: [PATCH] [debug] a better solution for mismatch of speech feat len and speech token len, refer to https://github.com/FunAudioLLM/CosyVoice/issues/1051

---
 cosyvoice/dataset/processor.py | 12 ++++++++++--
 cosyvoice/flow/flow.py         |  2 --
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/cosyvoice/dataset/processor.py b/cosyvoice/dataset/processor.py
index 8424ada..8ac82a1 100644
--- a/cosyvoice/dataset/processor.py
+++ b/cosyvoice/dataset/processor.py
@@ -159,6 +159,7 @@ def truncate(data, truncate_length=24576, mode='train'):
 
 def compute_fbank(data,
                   feat_extractor,
+                  token_mel_ratio=2,
                   mode='train'):
     """ Extract fbank
 
@@ -174,8 +175,15 @@ def compute_fbank(data,
         assert 'utt' in sample
         assert 'text_token' in sample
         waveform = sample['speech']
-        mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
-        sample['speech_feat'] = mat
+        feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
+
+        # padding with replicate mode (align to speech_token len * token_mel_ratio)
+        pad_len = sample["speech_token"].shape[0] * token_mel_ratio - feat.shape[0]
+        if pad_len > 0:
+            feat_to_pad = feat[-1:].repeat((pad_len, 1))
+            feat = torch.cat([feat, feat_to_pad], dim=0)
+
+        sample['speech_feat'] = feat
         yield sample
 
 
diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py
index 9c642ee..e1cf429 100644
--- a/cosyvoice/flow/flow.py
+++ b/cosyvoice/flow/flow.py
@@ -92,7 +92,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
 
         mask = (~make_pad_mask(feat_len)).to(h)
         # NOTE this is unnecessary, feat/h already same shape
-        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
         loss, _ = self.decoder.compute_loss(
             feat.transpose(1, 2).contiguous(),
             mask.unsqueeze(1),
@@ -214,7 +213,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         h = self.encoder_proj(h)
 
         # get conditions
-        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
         conds = torch.zeros(feat.shape, device=token.device)
         for i, j in enumerate(feat_len):
             if random.random() < 0.5: