online feature

lyuxiang.lx
2026-01-28 15:19:07 +00:00
parent 1822c5c908
commit 66b80dbccb
22 changed files with 133 additions and 116 deletions

@@ -16,12 +16,13 @@ import random
 import pyarrow.parquet as pq
 from io import BytesIO
 import numpy as np
 import torch
 import torchaudio
 from torch.nn.utils.rnn import pad_sequence
 import torch.nn.functional as F
 import pyworld as pw
+from cosyvoice.utils.onnx import embedding_extractor, online_feature
 AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
@@ -92,9 +93,9 @@ def filter(data,
             continue
         if len(sample['text_token']) > token_max_length:
             continue
-        if len(sample['speech_token']) == 0:
+        if online_feature is False and len(sample['speech_token']) == 0:
             continue
-        if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
+        if online_feature is False and 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
             continue
         if num_frames != 0:
             if len(sample['text_token']) / num_frames < min_output_input_ratio:
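
A note on the two rewritten checks in filter: when online_feature is enabled, samples reach this stage before any speech tokens have been extracted, so the empty-token guards must only fire on the offline path. A minimal sketch of the gating with a hypothetical keep() helper (the flag itself is assumed to be a module-level boolean exported by cosyvoice.utils.onnx):

def keep(sample, online_feature):
    # Offline, an empty token list marks a broken sample; online, tokens
    # simply have not been extracted yet, so the check must be skipped.
    if online_feature is False and len(sample.get('speech_token', [])) == 0:
        return False
    return True

assert keep({'speech_token': []}, online_feature=True) is True    # tolerated online
assert keep({'speech_token': []}, online_feature=False) is False  # dropped offline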
@@ -155,7 +156,7 @@ def truncate(data, truncate_length=24576, mode='train'):
 def compute_fbank(data,
                   feat_extractor,
-                  token_mel_ratio=0,
+                  num_frames=-1,
                   mode='train'):
     """ Extract fbank
@@ -170,14 +171,11 @@ def compute_fbank(data,
         assert 'speech' in sample
         assert 'utt' in sample
         assert 'text_token' in sample
-        waveform = sample['speech']
-        feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
-        if token_mel_ratio != 0:
-            # trim to align speech_token and speech_feat
-            token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
-            feat = feat[:token_mel_ratio * token_len]
-            sample["speech_token"] = sample["speech_token"][:token_len]
-        sample['speech_feat'] = feat
+        # NOTE in cosyvoice2/3, we support online token extraction, so we need to align speech to 25hz first
+        if num_frames != -1:
+            index = int(np.ceil(sample['speech'].shape[1] / num_frames))
+            sample['speech'] = torch.concat([sample['speech'], torch.zeros(1, index * num_frames - sample['speech'].shape[1])], dim=1)
+        sample['speech_feat'] = feat_extractor(sample['speech']).squeeze(dim=0).transpose(0, 1)
         yield sample
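
The new padding in compute_fbank rounds the waveform up to a whole number of num_frames-sample blocks, so that the 25 Hz token count and the mel frame count stay aligned for online token extraction. A worked example with assumed numbers (num_frames=960, i.e. one 25 Hz token per 960 samples at 24 kHz; neither figure is stated in the commit):

import numpy as np
import torch

num_frames, T = 960, 50000                                    # assumed values
speech = torch.randn(1, T)
index = int(np.ceil(speech.shape[1] / num_frames))            # ceil(50000/960) = 53
pad = torch.zeros(1, index * num_frames - speech.shape[1])    # 880 trailing zeros
speech = torch.concat([speech, pad], dim=1)
assert speech.shape[1] == 50880 and speech.shape[1] % num_frames == 0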
@@ -216,6 +214,10 @@ def parse_embedding(data, normalize, mode='train'):
             Iterable[{key, feat, label}]
     """
     for sample in data:
+        if 'utt_embedding' not in sample and 'spk_embedding' not in sample:
+            speech_16k = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
+            embedding = embedding_extractor.inference(speech_16k)
+            sample['spk_embedding'] = sample['utt_embedding'] = embedding
         sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
         sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
         if normalize:
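
One caveat with the on-the-fly branch above: torchaudio.transforms.Resample builds its filter kernel at construction time, so instantiating it once per sample is wasteful. A sketch of the same path with resamplers cached per input rate (the cache is an assumption of this note, not part of the commit):

import torchaudio

_resamplers = {}  # hypothetical cache keyed by source sample rate

def to_16k(speech, sample_rate):
    if sample_rate not in _resamplers:
        _resamplers[sample_rate] = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000)
    return _resamplers[sample_rate](speech)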
@@ -256,13 +258,14 @@ def shuffle(data, shuffle_size=10000, mode='train'):
             Iterable[{key, feat, label}]
     """
     buf = []
+    yield_size = int(shuffle_size / 2)
     for sample in data:
         buf.append(sample)
         if len(buf) >= shuffle_size:
             random.shuffle(buf)
-            for x in buf:
+            for x in buf[:yield_size]:
                 yield x
-            buf = []
+            buf = buf[yield_size:]
     # The sample left over
     random.shuffle(buf)
     for x in buf:
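
This change turns the flush-everything buffer into a sliding window: each flush emits the first half of the shuffled buffer and carries the second half forward, so samples can mix across flush boundaries instead of only within a single buffer. A toy run with shuffle_size=4:

import random

def shuffle(data, shuffle_size=4):
    buf, yield_size = [], int(shuffle_size / 2)
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            yield from buf[:yield_size]   # emit half ...
            buf = buf[yield_size:]        # ... keep half for the next window
    random.shuffle(buf)                   # flush the leftover tail
    yield from buf

print(list(shuffle(range(10))))  # e.g. [3, 1, 0, 4, 2, 7, 8, 5, 9, 6]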
@@ -420,10 +423,6 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
                                   padding_value=0)
         batch["pitch_feat"] = pitch_feat
         batch["pitch_feat_len"] = pitch_feat_len
-    else:
-        # only gan train needs speech, delete it to save memory
-        del batch["speech"]
-        del batch["speech_len"]
     if dpo is True:
         reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
         reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
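
Dropping the del batch["speech"] branch means every batch now retains the raw waveform and its lengths, presumably because online token extraction consumes them downstream, at the memory cost the old comment warned about. For reference, the reject-token batching that follows relies on pad_sequence in the usual way: variable-length 1-D token tensors become one (batch, max_len) tensor plus an int32 length vector. A self-contained illustration:

import torch
from torch.nn.utils.rnn import pad_sequence

tokens = [torch.tensor([5, 9, 2]), torch.tensor([7, 1])]
lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.int32)
padded = pad_sequence(tokens, batch_first=True, padding_value=0)
print(padded.shape, lens)  # torch.Size([2, 3]) tensor([3, 2], dtype=torch.int32)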