mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 09:29:25 +08:00
online feature
This commit is contained in:
@@ -16,12 +16,13 @@ import random
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import torch.nn.functional as F
|
||||
import pyworld as pw
|
||||
|
||||
from cosyvoice.utils.onnx import embedding_extractor, online_feature
|
||||
|
||||
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
|
||||
|
||||
@@ -92,9 +93,9 @@ def filter(data,
|
||||
continue
|
||||
if len(sample['text_token']) > token_max_length:
|
||||
continue
|
||||
if len(sample['speech_token']) == 0:
|
||||
if online_feature is False and len(sample['speech_token']) == 0:
|
||||
continue
|
||||
if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
|
||||
if online_feature is False and 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
|
||||
continue
|
||||
if num_frames != 0:
|
||||
if len(sample['text_token']) / num_frames < min_output_input_ratio:
|
||||
@@ -155,7 +156,7 @@ def truncate(data, truncate_length=24576, mode='train'):
|
||||
|
||||
def compute_fbank(data,
|
||||
feat_extractor,
|
||||
token_mel_ratio=0,
|
||||
num_frames=-1,
|
||||
mode='train'):
|
||||
""" Extract fbank
|
||||
|
||||
@@ -170,14 +171,11 @@ def compute_fbank(data,
|
||||
assert 'speech' in sample
|
||||
assert 'utt' in sample
|
||||
assert 'text_token' in sample
|
||||
waveform = sample['speech']
|
||||
feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
||||
if token_mel_ratio != 0:
|
||||
# trim to align speech_token and speech_feat
|
||||
token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
|
||||
feat = feat[:token_mel_ratio * token_len]
|
||||
sample["speech_token"] = sample["speech_token"][:token_len]
|
||||
sample['speech_feat'] = feat
|
||||
# NOTE in cosyvoice2/3, we support online token extraction, so we need to align speech to 25hz first
|
||||
if num_frames != -1:
|
||||
index = int(np.ceil(sample['speech'].shape[1] / num_frames))
|
||||
sample['speech'] = torch.concat([sample['speech'], torch.zeros(1, index * num_frames - sample['speech'].shape[1])], dim=1)
|
||||
sample['speech_feat'] = feat_extractor(sample['speech']).squeeze(dim=0).transpose(0, 1)
|
||||
yield sample
|
||||
|
||||
|
||||
@@ -216,6 +214,10 @@ def parse_embedding(data, normalize, mode='train'):
|
||||
Iterable[{key, feat, label}]
|
||||
"""
|
||||
for sample in data:
|
||||
if 'utt_embedding' not in sample and 'spk_embedding' not in sample:
|
||||
speech_16k = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
|
||||
embedding = embedding_extractor.inference(speech_16k)
|
||||
sample['spk_embedding'] = sample['utt_embedding'] = embedding
|
||||
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
|
||||
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
|
||||
if normalize:
|
||||
@@ -256,13 +258,14 @@ def shuffle(data, shuffle_size=10000, mode='train'):
|
||||
Iterable[{key, feat, label}]
|
||||
"""
|
||||
buf = []
|
||||
yield_size = int(shuffle_size / 2)
|
||||
for sample in data:
|
||||
buf.append(sample)
|
||||
if len(buf) >= shuffle_size:
|
||||
random.shuffle(buf)
|
||||
for x in buf:
|
||||
for x in buf[:yield_size]:
|
||||
yield x
|
||||
buf = []
|
||||
buf = buf[yield_size:]
|
||||
# The sample left over
|
||||
random.shuffle(buf)
|
||||
for x in buf:
|
||||
@@ -420,10 +423,6 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
|
||||
padding_value=0)
|
||||
batch["pitch_feat"] = pitch_feat
|
||||
batch["pitch_feat_len"] = pitch_feat_len
|
||||
else:
|
||||
# only gan train needs speech, delete it to save memory
|
||||
del batch["speech"]
|
||||
del batch["speech_len"]
|
||||
if dpo is True:
|
||||
reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
|
||||
reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
|
||||
|
||||
Reference in New Issue
Block a user