add files

This commit is contained in:
烨玮
2025-02-20 12:17:03 +08:00
parent a21dd4555c
commit edd008441b
667 changed files with 473123 additions and 0 deletions

View File

@@ -0,0 +1,96 @@
import logging
from pathlib import Path
from typing import Iterable
from typing import List
from typing import Union
import sentencepiece as spm
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from funasr_local.datasets.large_datasets.dataset import Dataset
from funasr_local.iterators.abs_iter_factory import AbsIterFactory
from funasr_local.text.abs_tokenizer import AbsTokenizer
def read_symbol_table(symbol_table_file):
if isinstance(symbol_table_file, str):
symbol_table = {}
with open(symbol_table_file, "r", encoding="utf8") as fin:
for i, line in enumerate(fin):
char = line.strip()
symbol_table[char] = i
else:
assert isinstance(symbol_table_file, list)
symbol_table = {}
for i, char in enumerate(symbol_table_file):
symbol_table[char] = i
return symbol_table
def load_seg_dict(seg_dict_file):
seg_dict = {}
assert isinstance(seg_dict_file, str)
with open(seg_dict_file, "r", encoding="utf8") as f:
lines = f.readlines()
for line in lines:
s = line.strip().split()
key = s[0]
value = s[1:]
seg_dict[key] = " ".join(value)
return seg_dict
class SentencepiecesTokenizer(AbsTokenizer):
def __init__(self, model: Union[Path, str]):
assert check_argument_types()
self.model = str(model)
self.sp = None
def __repr__(self):
return f'{self.__class__.__name__}(model="{self.model}")'
def _build_sentence_piece_processor(self):
if self.sp is None:
self.sp = spm.SentencePieceProcessor()
self.sp.load(self.model)
def text2tokens(self, line: str) -> List[str]:
self._build_sentence_piece_processor()
return self.sp.EncodeAsPieces(line)
def tokens2text(self, tokens: Iterable[str]) -> str:
self._build_sentence_piece_processor()
return self.sp.DecodePieces(list(tokens))
class ArkDataLoader(AbsIterFactory):
def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None,
bpemodel_file=None, mode="train"):
symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
if seg_dict_file is not None:
seg_dict = load_seg_dict(seg_dict_file)
else:
seg_dict = None
if punc_dict_file is not None:
punc_dict = read_symbol_table(punc_dict_file)
else:
punc_dict = None
self.dataset_conf = dataset_conf
self.frontend_conf = frontend_conf
logging.info("dataloader config: {}".format(self.dataset_conf))
batch_mode = self.dataset_conf.get("batch_mode", "padding")
if bpemodel_file is not None:
bpe_tokenizer = SentencepiecesTokenizer(bpemodel_file)
else:
bpe_tokenizer = None
self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
self.dataset_conf, self.frontend_conf, mode=mode, batch_mode=batch_mode)
def build_iter(self, epoch, shuffle=True):
self.dataset.set_epoch(epoch)
data_loader = DataLoader(self.dataset,
batch_size=None,
pin_memory=True,
num_workers=self.dataset_conf.get("num_workers", 8))
return data_loader

View File

@@ -0,0 +1,213 @@
import random
from itertools import count
from functools import partial
from torch.utils.data import IterableDataset
from funasr_local.datasets.large_datasets.datapipes.map import MapperIterDataPipe
tiebreaker = count()
def _default_len_fn(token):
return len(token), next(tiebreaker)
def _token_len_fn(token, len_fn):
return len_fn(token), next(tiebreaker), token
class MaxTokenBucketizerIterDataPipe(IterableDataset):
def __init__(
self,
datapipe,
batch_size=8000,
len_fn=_default_len_fn,
buffer_size=10240,
sort_size=500,
batch_mode="padding",
):
assert batch_size > 0, "Batch size is required to be larger than 0!"
assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
assert sort_size > 0, "Sort size is required to be larger than 0!"
datapipe = MapperIterDataPipe(datapipe, fn=partial(_token_len_fn, len_fn=len_fn))
self.datapipe = datapipe
self.batch_size = batch_size
self.buffer_size = buffer_size
self.sort_size = sort_size
self.batch_mode = batch_mode
def set_epoch(self, epoch):
self.epoch = epoch
def __iter__(self):
buffer = []
batch = []
bucket = []
max_lengths = 0
min_lengths = 999999
batch_lengths = 0
if self.batch_mode == "clipping":
assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1"
for d in self.datapipe:
if d[0] > self.batch_size:
continue
buffer.append(d)
if len(buffer) == self.buffer_size:
random.shuffle(buffer)
for sample in buffer:
bucket.append(sample)
if len(bucket) == self.sort_size:
bucket.sort()
for x in bucket:
length, _, token = x
if length < min_lengths:
min_lengths = length
batch_lengths = min_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
yield batch
batch = []
min_lengths = length
batch.append(token)
bucket = []
buffer = []
if buffer:
random.shuffle(buffer)
for sample in buffer:
bucket.append(sample)
if len(bucket) == self.sort_size:
bucket.sort()
for x in bucket:
length, _, token = x
if length < min_lengths:
min_lengths = length
batch_lengths = min_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
yield batch
batch = []
min_lengths = length
batch.append(token)
bucket = []
buffer = []
if bucket:
bucket.sort()
for x in bucket:
length, _, token = x
if length < min_lengths:
min_lengths = length
batch_lengths = min_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
yield batch
batch = []
min_lengths = length
batch.append(token)
bucket = []
if batch:
yield batch
else:
if self.buffer_size == -1:
for d in self.datapipe:
if d[0] > self.batch_size:
continue
buffer.append(d)
buffer.sort()
for sample in buffer:
length, _, token = sample
if length > max_lengths:
max_lengths = length
batch_lengths = max_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
bucket.append(batch)
batch = []
max_lengths = length
batch.append(token)
random.shuffle(bucket)
if bucket:
for batch_sample in bucket:
yield batch_sample
if batch:
yield batch
elif self.buffer_size == 0:
for d in self.datapipe:
if d[0] > self.batch_size:
continue
length, _, token = d
if length > self.batch_size:
continue
if length > max_lengths:
max_lengths = length
batch_lengths = max_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
yield batch
batch = []
max_lengths = length
batch.append(token)
if batch:
yield batch
else:
for d in self.datapipe:
if d[0] > self.batch_size:
continue
buffer.append(d)
if len(buffer) == self.buffer_size:
random.shuffle(buffer)
for sample in buffer:
bucket.append(sample)
if len(bucket) == self.sort_size:
bucket.sort()
for x in bucket:
length, _, token = x
if length > max_lengths:
max_lengths = length
batch_lengths = max_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
yield batch
batch = []
max_lengths = length
batch.append(token)
bucket = []
buffer = []
if buffer:
random.shuffle(buffer)
for sample in buffer:
bucket.append(sample)
if len(bucket) == self.sort_size:
bucket.sort()
for x in bucket:
length, _, token = x
if length > max_lengths:
max_lengths = length
batch_lengths = max_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
yield batch
batch = []
max_lengths = length
batch.append(token)
bucket = []
buffer = []
if bucket:
bucket.sort()
for x in bucket:
length, _, token = x
if length > max_lengths:
max_lengths = length
batch_lengths = max_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
yield batch
batch = []
max_lengths = length
batch.append(token)
bucket = []
if batch:
yield batch

View File

@@ -0,0 +1,24 @@
from torch.utils.data import IterableDataset
def default_fn(data):
return data
class FilterIterDataPipe(IterableDataset):
def __init__(self,
datapipe,
fn=default_fn):
self.datapipe = datapipe
self.fn = fn
def set_epoch(self, epoch):
self.epoch = epoch
def __iter__(self):
assert callable(self.fn)
for data in self.datapipe:
if self.fn(data):
yield data
else:
continue

View File

@@ -0,0 +1,22 @@
from torch.utils.data import IterableDataset
def default_fn(data):
return data
class MapperIterDataPipe(IterableDataset):
def __init__(self,
datapipe,
fn=default_fn):
self.datapipe = datapipe
self.fn = fn
def set_epoch(self, epoch):
self.epoch = epoch
def __iter__(self):
assert callable(self.fn)
for data in self.datapipe:
yield self.fn(data)

View File

@@ -0,0 +1,212 @@
import os
import random
import numpy
from functools import partial
import torch
import torchaudio
import torch.distributed as dist
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
from funasr_local.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
from funasr_local.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
from funasr_local.datasets.large_datasets.datapipes.map import MapperIterDataPipe
from funasr_local.datasets.large_datasets.utils.filter import filter
from funasr_local.datasets.large_datasets.utils.padding import padding
from funasr_local.datasets.large_datasets.utils.clipping import clipping
from funasr_local.datasets.large_datasets.utils.tokenize import tokenize
def read_lists(list_file):
lists = []
with open(list_file, 'r', encoding='utf8') as fin:
for line in fin:
parts = line.strip()
lists.append(parts)
return lists
class AudioDataset(IterableDataset):
def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train"):
self.scp_lists = scp_lists
self.data_names = data_names
self.data_types = data_types
self.frontend_conf = frontend_conf
self.shuffle = shuffle
self.mode = mode
self.epoch = -1
self.rank = 0
self.world_size = 1
self.worker_id = 0
self.num_workers = 1
def set_epoch(self, epoch):
self.epoch = epoch
def get_rank_data_list(self, data_index):
assert dist.is_available()
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
else:
self.rank = 0
self.world_size = 1
if self.mode == "train":
if self.shuffle:
random.seed(self.epoch)
random.shuffle(data_index)
return data_index[self.rank::self.world_size]
return data_index
def get_worker_data_list(self, rank_data_index):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.worker_id = 0
self.num_workers = 1
else:
self.worker_id = worker_info.id
self.num_workers = worker_info.num_workers
return rank_data_index[self.worker_id::self.num_workers]
def close_reader(self, reader_list):
for reader in reader_list:
reader.close()
def __iter__(self):
data_index = list(range(len(self.scp_lists)))
rank_data_index = self.get_rank_data_list(data_index)
worker_data_index = self.get_worker_data_list(rank_data_index)
for index in worker_data_index:
data = dict(scp=self.scp_lists[index])
assert 'scp' in data
scp = data['scp']
data_file_list = scp.strip().split()
data_name_list = self.data_names.split(",")
data_type_list = self.data_types.split(",")
for file in data_file_list:
assert os.path.exists(file), "{} not exists".format(file)
assert len(data_file_list) == len(data_name_list) == len(data_type_list), \
"The item number of data, data_names, data_types must be the same "
reader_list = []
for data_file, data_type in zip(data_file_list, data_type_list):
if data_type == "kaldi_ark":
ark_reader = ReadHelper('ark:{}'.format(data_file))
reader_list.append(ark_reader)
elif data_type == "text" or data_type == "sound":
text_reader = open(data_file, "r")
reader_list.append(text_reader)
elif data_type == "none":
continue
else:
raise TypeError("Data type {} is not supported".format(data_type))
for items in zip(*reader_list):
sample_dict = {}
for item, (data_name, data_type) in zip(items, zip(data_name_list, data_type_list)):
if data_type == "kaldi_ark":
key, mat = item
sample_dict[data_name] = mat
if data_name == "speech":
sample_dict["key"] = key
elif data_type == "sound":
key, path = item.strip().split()
waveform, sampling_rate = torchaudio.load(path)
if self.frontend_conf is not None:
if sampling_rate != self.frontend_conf["fs"]:
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
new_freq=self.frontend_conf["fs"])(waveform)
sampling_rate = self.frontend_conf["fs"]
waveform = waveform.numpy()
mat = waveform[0]
sample_dict[data_name] = mat
sample_dict["sampling_rate"] = sampling_rate
if data_name == "speech":
sample_dict["key"] = key
else:
text = item
segs = text.strip().split()
sample_dict[data_name] = segs[1:]
if "key" not in sample_dict:
sample_dict["key"] = segs[0]
yield sample_dict
self.close_reader(reader_list)
def len_fn_example(data):
return 1
def len_fn_token(data):
assert "speech" in data
if "sampling_rate" in data:
return (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
else:
return data["speech"].shape[0]
def Dataset(data_list_file,
dict,
seg_dict,
punc_dict,
bpe_tokenizer,
conf,
frontend_conf,
mode="train",
batch_mode="padding"):
scp_lists = read_lists(data_list_file)
shuffle = conf.get('shuffle', True)
data_names = conf.get("data_names", "speech,text")
data_types = conf.get("data_types", "kaldi_ark,text")
dataset = AudioDataset(scp_lists, data_names, data_types, frontend_conf=frontend_conf, shuffle=shuffle, mode=mode)
filter_conf = conf.get('filter_conf', {})
filter_fn = partial(filter, **filter_conf)
dataset = FilterIterDataPipe(dataset, fn=filter_fn)
if "text" in data_names:
vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer}
tokenize_fn = partial(tokenize, **vocab)
dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
if shuffle:
buffer_conf = conf.get('shuffle_conf', {})
buffer_size = buffer_conf['shuffle_size']
sort_size = buffer_conf['sort_size']
else:
buffer_size = 0
sort_size = 1
batch_conf = conf.get('batch_conf', {})
batch_size = batch_conf['batch_size']
batch_type = batch_conf['batch_type']
assert batch_type in ["example", "token"]
if batch_type == 'example':
len_fn = len_fn_example
else:
len_fn = len_fn_token
dataset = MaxTokenBucketizerIterDataPipe(dataset,
batch_size=batch_size,
len_fn=len_fn,
buffer_size=buffer_size,
sort_size=sort_size,
batch_mode=batch_mode)
int_pad_value = conf.get("int_pad_value", -1)
float_pad_value = conf.get("float_pad_value", 0.0)
padding_conf = {"int_pad_value": int_pad_value, "float_pad_value": float_pad_value}
padding_fn = partial(padding, **padding_conf)
dataset = MapperIterDataPipe(dataset, fn=padding_fn if batch_mode == "padding" else clipping)
return dataset

View File

@@ -0,0 +1,40 @@
import numpy as np
import torch
from funasr_local.datasets.collate_fn import crop_to_max_size
def clipping(data):
assert isinstance(data, list)
assert "key" in data[0]
keys = [x["key"] for x in data]
batch = {}
data_names = data[0].keys()
for data_name in data_names:
if data_name == "key":
continue
else:
if data[0][data_name].dtype.kind == "i":
tensor_type = torch.int64
else:
tensor_type = torch.float32
tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
length_clip = min(tensor_lengths)
tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
diff = length - length_clip
assert diff >= 0
if diff == 0:
tensor_clip[i] = tensor
else:
tensor_clip[i] = crop_to_max_size(tensor, length_clip)
batch[data_name] = tensor_clip
batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long)
return keys, batch

View File

@@ -0,0 +1,26 @@
#!/usr/bin/env python
def filter(data,
speech_length_min=100,
speech_length_max=15000,
token_length_min=0,
token_length_max=200):
assert "speech" in data or "text" in data
if "speech" in data and "text" in data:
if "sampling_rate" in data:
speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
else:
speech_length = data["speech"].shape[0]
num_tokens = len(data['text'])
return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
elif "speech" in data:
if "sampling_rate" in data:
speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
else:
speech_length = data["speech"].shape[0]
return speech_length_min < speech_length < speech_length_max
else:
num_tokens = len(data['text'])
return token_length_min < num_tokens < token_length_max

View File

@@ -0,0 +1,30 @@
import numpy as np
def build_LFR_features(data, m, n):
"""
Actually, this implements stacking frames and skipping frames.
if m = 1 and n = 1, just return the origin features.
if m = 1 and n > 1, it works like skipping.
if m > 1 and n = 1, it works like stacking but only support right frames.
if m > 1 and n > 1, it works like LFR.
Args:
inputs_batch: inputs is T x D np.ndarray
m: number of frames to stack
n: number of frames to skip
"""
LFR_inputs = []
T = data.shape[0]
T_lfr = int(np.ceil(T / n))
for i in range(T_lfr):
if m <= T - i * n:
LFR_inputs.append(np.hstack(data[i*n:i*n+m]))
else:
num_padding = m - (T - i * n)
frame = np.hstack(data[i*n:])
for _ in range(num_padding):
frame = np.hstack((frame, data[-1]))
LFR_inputs.append(frame)
return np.vstack(LFR_inputs)

View File

@@ -0,0 +1,34 @@
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
def padding(data, float_pad_value=0.0, int_pad_value=-1):
assert isinstance(data, list)
assert "key" in data[0]
assert "speech" in data[0] or "text" in data[0]
keys = [x["key"] for x in data]
batch = {}
data_names = data[0].keys()
for data_name in data_names:
if data_name == "key" or data_name =="sampling_rate":
continue
else:
if data[0][data_name].dtype.kind == "i":
pad_value = int_pad_value
tensor_type = torch.int64
else:
pad_value = float_pad_value
tensor_type = torch.float32
tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
tensor_pad = pad_sequence(tensor_list,
batch_first=True,
padding_value=pad_value)
batch[data_name] = tensor_pad
batch[data_name + "_lengths"] = tensor_lengths
return keys, batch

View File

@@ -0,0 +1,81 @@
#!/usr/bin/env python
import re
import numpy as np
def forward_segment(text, seg_dict):
word_list = []
i = 0
while i < len(text):
longest_word = text[i]
for j in range(i + 1, len(text) + 1):
word = text[i:j]
if word in seg_dict:
if len(word) > len(longest_word):
longest_word = word
word_list.append(longest_word)
i += len(longest_word)
return word_list
def seg_tokenize(txt, seg_dict):
pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
out_txt = ""
for word in txt:
word = word.lower()
if word in seg_dict:
out_txt += seg_dict[word] + " "
else:
if pattern.match(word):
for char in word:
if char in seg_dict:
out_txt += seg_dict[char] + " "
else:
out_txt += "<unk>" + " "
else:
out_txt += "<unk>" + " "
return out_txt.strip().split()
def tokenize(data,
vocab=None,
seg_dict=None,
punc_dict=None,
bpe_tokenizer=None):
assert "text" in data
assert isinstance(vocab, dict)
text = data["text"]
token = []
vad = -2
if bpe_tokenizer is not None:
text = bpe_tokenizer.text2tokens("".join(text))
if seg_dict is not None:
assert isinstance(seg_dict, dict)
text = seg_tokenize(text, seg_dict)
length = len(text)
for i in range(length):
x = text[i]
if i == length-1 and "punc" in data and x.startswith("vad:"):
vad = x[4:]
if len(vad) == 0:
vad = -1
else:
vad = int(vad)
elif x in vocab:
token.append(vocab[x])
else:
token.append(vocab['<unk>'])
if "punc" in data and punc_dict is not None:
punc_token = []
for punc in data["punc"]:
if punc in punc_dict:
punc_token.append(punc_dict[punc])
else:
punc_token.append(punc_dict["_"])
data["punc"] = np.array(punc_token)
data["text"] = np.array(token)
if vad is not -2:
data["vad_indexes"]=np.array([vad], dtype=np.int64)
return data