mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
[debug] A better solution for the mismatch between speech feat length and speech token length; see https://github.com/FunAudioLLM/CosyVoice/issues/1051
This commit is contained in:
@@ -159,6 +159,7 @@ def truncate(data, truncate_length=24576, mode='train'):
|
|||||||
|
|
||||||
def compute_fbank(data,
|
def compute_fbank(data,
|
||||||
feat_extractor,
|
feat_extractor,
|
||||||
|
token_mel_ratio=2,
|
||||||
mode='train'):
|
mode='train'):
|
||||||
""" Extract fbank
|
""" Extract fbank
|
||||||
|
|
||||||
@@ -174,8 +175,15 @@ def compute_fbank(data,
|
|||||||
assert 'utt' in sample
|
assert 'utt' in sample
|
||||||
assert 'text_token' in sample
|
assert 'text_token' in sample
|
||||||
waveform = sample['speech']
|
waveform = sample['speech']
|
||||||
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
||||||
sample['speech_feat'] = mat
|
|
||||||
|
# padding with replicate mode (align to speech_token len * token_mel_ratio)
|
||||||
|
pad_len = sample["speech_token"].shape[0] * token_mel_ratio - feat.shape[0]
|
||||||
|
if pad_len > 0:
|
||||||
|
feat_to_pad = feat[-1:].repeat((pad_len, 1))
|
||||||
|
feat = torch.cat([feat, feat_to_pad], dim=0)
|
||||||
|
|
||||||
|
sample['speech_feat'] = feat
|
||||||
yield sample
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -92,7 +92,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|||||||
|
|
||||||
mask = (~make_pad_mask(feat_len)).to(h)
|
mask = (~make_pad_mask(feat_len)).to(h)
|
||||||
# NOTE this is unnecessary, feat/h already same shape
|
# NOTE this is unnecessary, feat/h already same shape
|
||||||
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
|
||||||
loss, _ = self.decoder.compute_loss(
|
loss, _ = self.decoder.compute_loss(
|
||||||
feat.transpose(1, 2).contiguous(),
|
feat.transpose(1, 2).contiguous(),
|
||||||
mask.unsqueeze(1),
|
mask.unsqueeze(1),
|
||||||
@@ -214,7 +213,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|||||||
h = self.encoder_proj(h)
|
h = self.encoder_proj(h)
|
||||||
|
|
||||||
# get conditions
|
# get conditions
|
||||||
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
|
||||||
conds = torch.zeros(feat.shape, device=token.device)
|
conds = torch.zeros(feat.shape, device=token.device)
|
||||||
for i, j in enumerate(feat_len):
|
for i, j in enumerate(feat_len):
|
||||||
if random.random() < 0.5:
|
if random.random() < 0.5:
|
||||||
|
|||||||
Reference in New Issue
Block a user