mirror of https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
0
funasr_local/datasets/__init__.py
Normal file
135
funasr_local/datasets/collate_fn.py
Normal file
@@ -0,0 +1,135 @@

from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union

import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr_local.modules.nets_utils import pad_list


class CommonCollateFn:
    """Functor class of common_collate_fn()"""

    def __init__(
        self,
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
        max_sample_size=None,
    ):
        assert check_argument_types()
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        self.max_sample_size = max_sample_size

    def __repr__(self):
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value})"
        )

    def __call__(
        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        return common_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
        )


def common_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Concatenate ndarray-list to an array and convert to torch.Tensor."""
    assert check_argument_types()
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in data[0]
    ), f"*_lengths is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        if data[0][key].dtype.kind == "i":
            pad_value = int_pad_value
        else:
            pad_value = float_pad_value

        array_list = [d[key] for d in data]
        tensor_list = [torch.from_numpy(a) for a in array_list]
        tensor = pad_list(tensor_list, pad_value)
        output[key] = tensor

        if key not in not_sequence:
            lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
            output[key + "_lengths"] = lens

    output = (uttids, output)
    assert check_return_type(output)
    return output


def crop_to_max_size(feature, target_size):
    size = len(feature)
    diff = size - target_size
    if diff <= 0:
        return feature

    start = np.random.randint(0, diff + 1)
    end = size - diff + start
    return feature[start:end]


def clipping_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    max_sample_size=None,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    # mainly for pre-training
    assert check_argument_types()
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in data[0]
    ), f"*_lengths is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        array_list = [d[key] for d in data]
        tensor_list = [torch.from_numpy(a) for a in array_list]
        sizes = [len(s) for s in tensor_list]
        if max_sample_size is None:
            target_size = min(sizes)
        else:
            target_size = min(min(sizes), max_sample_size)
        tensor = tensor_list[0].new_zeros(len(tensor_list), target_size, tensor_list[0].shape[1])
        for i, (source, size) in enumerate(zip(tensor_list, sizes)):
            diff = size - target_size
            if diff == 0:
                tensor[i] = source
            else:
                tensor[i] = crop_to_max_size(source, target_size)
        output[key] = tensor

        if key not in not_sequence:
            lens = torch.tensor([source.shape[0] for source in tensor], dtype=torch.long)
            output[key + "_lengths"] = lens

    output = (uttids, output)
    assert check_return_type(output)
    return output
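Usage note: a minimal sketch of common_collate_fn on a two-utterance batch. The utterance ids, shapes, and the assumption that funasr_local is importable are illustrative, not part of this commit.

import numpy as np
from funasr_local.datasets.collate_fn import common_collate_fn

# Two utterances: float features of unequal length plus int token ids.
batch = [
    ("utt1", {"speech": np.zeros((30, 80), dtype=np.float32),
              "text": np.array([3, 8, 4], dtype=np.int64)}),
    ("utt2", {"speech": np.zeros((20, 80), dtype=np.float32),
              "text": np.array([5, 9], dtype=np.int64)}),
]
uttids, tensors = common_collate_fn(batch)
# uttids == ["utt1", "utt2"]
# tensors["speech"].shape == (2, 30, 80): float arrays are padded with 0.0
# tensors["text"][1] == tensor([5, 9, -32768]): int arrays are padded with -32768
# tensors["speech_lengths"] == tensor([30, 20]): added because "speech" is a sequence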
448
funasr_local/datasets/dataset.py
Normal file
@@ -0,0 +1,448 @@

# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

from abc import ABC
from abc import abstractmethod
import collections
import copy
import functools
import logging
import numbers
import re
from typing import Any
from typing import Callable
from typing import Collection
from typing import Dict
from typing import Mapping
from typing import Tuple
from typing import Union

import h5py
import humanfriendly
# import kaldiio
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr_local.fileio.npy_scp import NpyScpReader
from funasr_local.fileio.rand_gen_dataset import FloatRandomGenerateDataset
from funasr_local.fileio.rand_gen_dataset import IntRandomGenerateDataset
from funasr_local.fileio.read_text import load_num_sequence_text
from funasr_local.fileio.read_text import read_2column_text
from funasr_local.fileio.sound_scp import SoundScpReader
from funasr_local.utils.sized_dict import SizedDict


class AdapterForSoundScpReader(collections.abc.Mapping):
    def __init__(self, loader, dtype=None):
        assert check_argument_types()
        self.loader = loader
        self.dtype = dtype
        self.rate = None

    def keys(self):
        return self.loader.keys()

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        return iter(self.loader)

    def __getitem__(self, key: str) -> np.ndarray:
        retval = self.loader[key]

        if isinstance(retval, tuple):
            assert len(retval) == 2, len(retval)
            if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
                # sound scp case
                rate, array = retval
            elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
                # Extended ark format case
                array, rate = retval
            else:
                raise RuntimeError(
                    f"Unexpected type: {type(retval[0])}, {type(retval[1])}"
                )

            if self.rate is not None and self.rate != rate:
                raise RuntimeError(
                    f"Sampling rates are mismatched: {self.rate} != {rate}"
                )
            self.rate = rate
            # Multichannel wave file
            # array: (NSample, Channel) or (NSample,)
            if self.dtype is not None:
                array = array.astype(self.dtype)

        else:
            # Normal ark case
            assert isinstance(retval, np.ndarray), type(retval)
            array = retval
            if self.dtype is not None:
                array = array.astype(self.dtype)

        assert isinstance(array, np.ndarray), type(array)
        return array


class H5FileWrapper:
    def __init__(self, path: str):
        self.path = path
        self.h5_file = h5py.File(path, "r")

    def __repr__(self) -> str:
        return str(self.h5_file)

    def __len__(self) -> int:
        return len(self.h5_file)

    def __iter__(self):
        return iter(self.h5_file)

    def __getitem__(self, key) -> np.ndarray:
        value = self.h5_file[key]
        return value[()]


def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
    # The file is as follows:
    #   utterance_id_A /some/where/a.wav
    #   utterance_id_B /some/where/a.flac

    # NOTE(kamo): SoundScpReader doesn't support pipe-fashion
    # like Kaldi e.g. "cat a.wav |".
    # NOTE(kamo): The audio signal is normalized to the [-1, 1] range.
    loader = SoundScpReader(path, normalize=True, always_2d=False, dest_sample_rate=dest_sample_rate)

    # SoundScpReader.__getitem__() returns Tuple[int, ndarray],
    # but an ndarray is desired, so the Adapter class is inserted here
    return AdapterForSoundScpReader(loader, float_dtype)


def kaldi_loader(path, float_dtype=None, max_cache_fd: int = 0):
    loader = kaldiio.load_scp(path, max_cache_fd=max_cache_fd)
    return AdapterForSoundScpReader(loader, float_dtype)


def rand_int_loader(filepath, loader_type):
    # e.g. rand_int_3_10
    try:
        low, high = map(int, loader_type[len("rand_int_"):].split("_"))
    except ValueError:
        raise RuntimeError(f"e.g. rand_int_3_10: but got {loader_type}")
    return IntRandomGenerateDataset(filepath, low, high)


DATA_TYPES = {
    "sound": dict(
        func=sound_loader,
        kwargs=["dest_sample_rate", "float_dtype"],
        help="Audio format types which are supported by sndfile: wav, flac, etc."
        "\n\n"
        "   utterance_id_a a.wav\n"
        "   utterance_id_b b.wav\n"
        "   ...",
    ),
    "kaldi_ark": dict(
        func=kaldi_loader,
        kwargs=["max_cache_fd"],
        help="Kaldi-ark file type."
        "\n\n"
        "   utterance_id_A /some/where/a.ark:123\n"
        "   utterance_id_B /some/where/a.ark:456\n"
        "   ...",
    ),
    "npy": dict(
        func=NpyScpReader,
        kwargs=[],
        help="Npy file format."
        "\n\n"
        "   utterance_id_A /some/where/a.npy\n"
        "   utterance_id_B /some/where/b.npy\n"
        "   ...",
    ),
    "text_int": dict(
        func=functools.partial(load_num_sequence_text, loader_type="text_int"),
        kwargs=[],
        help="A text file in which is written a sequence of integer numbers "
        "separated by space."
        "\n\n"
        "   utterance_id_A 12 0 1 3\n"
        "   utterance_id_B 3 3 1\n"
        "   ...",
    ),
    "csv_int": dict(
        func=functools.partial(load_num_sequence_text, loader_type="csv_int"),
        kwargs=[],
        help="A text file in which is written a sequence of integer numbers "
        "separated by comma."
        "\n\n"
        "   utterance_id_A 100,80\n"
        "   utterance_id_B 143,80\n"
        "   ...",
    ),
    "text_float": dict(
        func=functools.partial(load_num_sequence_text, loader_type="text_float"),
        kwargs=[],
        help="A text file in which is written a sequence of float numbers "
        "separated by space."
        "\n\n"
        "   utterance_id_A 12. 3.1 3.4 4.4\n"
        "   utterance_id_B 3. 3.12 1.1\n"
        "   ...",
    ),
    "csv_float": dict(
        func=functools.partial(load_num_sequence_text, loader_type="csv_float"),
        kwargs=[],
        help="A text file in which is written a sequence of float numbers "
        "separated by comma."
        "\n\n"
        "   utterance_id_A 12.,3.1,3.4,4.4\n"
        "   utterance_id_B 3.,3.12,1.1\n"
        "   ...",
    ),
    "text": dict(
        func=read_2column_text,
        kwargs=[],
        help="Return text as is. The text must be converted to ndarray "
        "by 'preprocess'."
        "\n\n"
        "   utterance_id_A hello world\n"
        "   utterance_id_B foo bar\n"
        "   ...",
    ),
    "hdf5": dict(
        func=H5FileWrapper,
        kwargs=[],
        help="An HDF5 file which contains arrays at the first level or the second level."
        "\n\n"
        "   >>> f = h5py.File('file.h5')\n"
        "   >>> array1 = f['utterance_id_A']\n"
        "   >>> array2 = f['utterance_id_B']\n",
    ),
    "rand_float": dict(
        func=FloatRandomGenerateDataset,
        kwargs=[],
        help="Generate random float-ndarray which has the given shapes "
        "in the file."
        "\n\n"
        "   utterance_id_A 3,4\n"
        "   utterance_id_B 10,4\n"
        "   ...",
    ),
    "rand_int_\\d+_\\d+": dict(
        func=rand_int_loader,
        kwargs=["loader_type"],
        help="e.g. 'rand_int_0_10'. Generate random int-ndarray which has the given "
        "shapes in the path. "
        "Give the lower and upper value by the file type. e.g. "
        "rand_int_0_10 -> Generate integers from 0 to 10."
        "\n\n"
        "   utterance_id_A 3,4\n"
        "   utterance_id_B 10,4\n"
        "   ...",
    ),
}


class AbsDataset(Dataset, ABC):
    @abstractmethod
    def has_name(self, name) -> bool:
        raise NotImplementedError

    @abstractmethod
    def names(self) -> Tuple[str, ...]:
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, uid) -> Tuple[Any, Dict[str, np.ndarray]]:
        raise NotImplementedError


class ESPnetDataset(AbsDataset):
    """Pytorch Dataset class for ESPnet.

    Examples:
        >>> dataset = ESPnetDataset([('wav.scp', 'input', 'sound'),
        ...                          ('token_int', 'output', 'text_int')],
        ...                         )
        >>> uttid, data = dataset['uttid']
        >>> data
        {'input': per_utt_array, 'output': per_utt_array}
    """

    def __init__(
        self,
        path_name_type_list: Collection[Tuple[str, str, str]],
        preprocess: Callable[
            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
        ] = None,
        float_dtype: str = "float32",
        int_dtype: str = "long",
        max_cache_size: Union[float, int, str] = 0.0,
        max_cache_fd: int = 0,
        dest_sample_rate: int = 16000,
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"'
            )

        path_name_type_list = copy.deepcopy(path_name_type_list)
        self.preprocess = preprocess

        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.max_cache_fd = max_cache_fd
        self.dest_sample_rate = dest_sample_rate

        self.loader_dict = {}
        self.debug_info = {}
        for path, name, _type in path_name_type_list:
            if name in self.loader_dict:
                raise RuntimeError(f'"{name}" is duplicated for data-key')

            loader = self._build_loader(path, _type)
            self.loader_dict[name] = loader
            self.debug_info[name] = path, _type
            if len(self.loader_dict[name]) == 0:
                raise RuntimeError(f"{path} has no samples")

        # TODO(kamo): Should check consistency of each utt-keys?

        if isinstance(max_cache_size, str):
            max_cache_size = humanfriendly.parse_size(max_cache_size)
        self.max_cache_size = max_cache_size
        if max_cache_size > 0:
            self.cache = SizedDict(shared=True)
        else:
            self.cache = None

    def _build_loader(
        self, path: str, loader_type: str
    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
        """Helper function to instantiate Loader.

        Args:
            path: The file path
            loader_type: loader_type. sound, npy, text_int, text_float, etc.
        """
        for key, dic in DATA_TYPES.items():
            # e.g. loader_type="sound"
            # -> return DATA_TYPES["sound"]["func"](path)
            if re.match(key, loader_type):
                kwargs = {}
                for key2 in dic["kwargs"]:
                    if key2 == "loader_type":
                        kwargs["loader_type"] = loader_type
                    elif key2 == "dest_sample_rate" and loader_type == "sound":
                        kwargs["dest_sample_rate"] = self.dest_sample_rate
                    elif key2 == "float_dtype":
                        kwargs["float_dtype"] = self.float_dtype
                    elif key2 == "int_dtype":
                        kwargs["int_dtype"] = self.int_dtype
                    elif key2 == "max_cache_fd":
                        kwargs["max_cache_fd"] = self.max_cache_fd
                    else:
                        raise RuntimeError(f"Not implemented keyword argument: {key2}")

                func = dic["func"]
                try:
                    return func(path, **kwargs)
                except Exception:
                    if hasattr(func, "__name__"):
                        name = func.__name__
                    else:
                        name = str(func)
                    logging.error(f"An error happened with {name}({path})")
                    raise
        else:
            raise RuntimeError(f"Not supported: loader_type={loader_type}")

    def has_name(self, name) -> bool:
        return name in self.loader_dict

    def names(self) -> Tuple[str, ...]:
        return tuple(self.loader_dict)

    def __iter__(self):
        return iter(next(iter(self.loader_dict.values())))

    def __repr__(self):
        _mes = self.__class__.__name__
        _mes += "("
        for name, (path, _type) in self.debug_info.items():
            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
        _mes += f"\n  preprocess: {self.preprocess})"
        return _mes

    def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
        assert check_argument_types()

        # Change integer-id to string-id
        if isinstance(uid, int):
            d = next(iter(self.loader_dict.values()))
            uid = list(d)[uid]

        if self.cache is not None and uid in self.cache:
            data = self.cache[uid]
            return uid, data

        data = {}
        # 1. Load data from each loader
        for name, loader in self.loader_dict.items():
            try:
                value = loader[uid]
                if isinstance(value, (list, tuple)):
                    value = np.array(value)
                if not isinstance(
                    value, (np.ndarray, torch.Tensor, str, numbers.Number)
                ):
                    raise TypeError(
                        f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
                    )
            except Exception:
                path, _type = self.debug_info[name]
                logging.error(
                    f"Error happened with path={path}, type={_type}, id={uid}"
                )
                raise

            # torch.Tensor is converted to ndarray
            if isinstance(value, torch.Tensor):
                value = value.numpy()
            elif isinstance(value, numbers.Number):
                value = np.array([value])
            data[name] = value

        # 2. [Option] Apply preprocessing
        #    e.g. funasr_local.train.preprocessor:CommonPreprocessor
        if self.preprocess is not None:
            data = self.preprocess(uid, data)

        # 3. Force data-precision
        for name in data:
            value = data[name]
            if not isinstance(value, np.ndarray):
                raise RuntimeError(
                    f"All values must be converted to np.ndarray object "
                    f'by preprocessing, but "{name}" is still {type(value)}.'
                )

            # Cast to desired type
            if value.dtype.kind == "f":
                value = value.astype(self.float_dtype)
            elif value.dtype.kind == "i":
                value = value.astype(self.int_dtype)
            else:
                raise NotImplementedError(f"Not supported dtype: {value.dtype}")
            data[name] = value

        if self.cache is not None and self.cache.size < self.max_cache_size:
            self.cache[uid] = data

        retval = uid, data
        assert check_return_type(retval)
        return retval
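Usage note: a random-access sketch of ESPnetDataset following the class docstring above. The scp paths are hypothetical placeholders; also note that the "kaldi_ark" type would require uncommenting the kaldiio import, since kaldi_loader references kaldiio.

from funasr_local.datasets.dataset import ESPnetDataset

dataset = ESPnetDataset(
    [("dump/train/wav.scp", "speech", "sound"),
     ("dump/train/text_int", "text", "text_int")],
)
uid, data = dataset["utterance_id_A"]  # an integer index also works
# data["speech"]: float waveform array, cast to float_dtype ("float32")
# data["text"]:   integer token-id array, cast to int_dtype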
388
funasr_local/datasets/iterable_dataset.py
Normal file
@@ -0,0 +1,388 @@

"""Iterable dataset module."""
import copy
from io import StringIO
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import Iterator
from typing import Tuple
from typing import Union
from typing import List

# import kaldiio
import numpy as np
import torch
import torchaudio
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path

from funasr_local.datasets.dataset import ESPnetDataset


SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']


def load_kaldi(input):
    retval = kaldiio.load_mat(input)
    if isinstance(retval, tuple):
        assert len(retval) == 2, len(retval)
        if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
            # sound scp case
            rate, array = retval
        elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
            # Extended ark format case
            array, rate = retval
        else:
            raise RuntimeError(f"Unexpected type: {type(retval[0])}, {type(retval[1])}")

        # Multichannel wave file
        # array: (NSample, Channel) or (NSample,)

    else:
        # Normal ark case
        assert isinstance(retval, np.ndarray), type(retval)
        array = retval
    return array


def load_bytes(input):
    middle_data = np.frombuffer(input, dtype=np.int16)
    middle_data = np.asarray(middle_data)
    if middle_data.dtype.kind not in 'iu':
        raise TypeError("'middle_data' must be an array of integers")
    dtype = np.dtype('float32')
    if dtype.kind != 'f':
        raise TypeError("'dtype' must be a floating point type")

    i = np.iinfo(middle_data.dtype)
    abs_max = 2 ** (i.bits - 1)
    offset = i.min + abs_max
    array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
    return array


def load_pcm(input):
    with open(input, "rb") as f:
        bytes = f.read()
    return load_bytes(bytes)


DATA_TYPES = {
    "sound": lambda x: torchaudio.load(x)[0].numpy(),
    "pcm": load_pcm,
    "kaldi_ark": load_kaldi,
    "bytes": load_bytes,
    "waveform": lambda x: x,
    "npy": np.load,
    "text_int": lambda x: np.loadtxt(
        StringIO(x), ndmin=1, dtype=np.long, delimiter=" "
    ),
    "csv_int": lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=","),
    "text_float": lambda x: np.loadtxt(
        StringIO(x), ndmin=1, dtype=np.float32, delimiter=" "
    ),
    "csv_float": lambda x: np.loadtxt(
        StringIO(x), ndmin=1, dtype=np.float32, delimiter=","
    ),
    "text": lambda x: x,
}


class IterableESPnetDataset(IterableDataset):
    """Pytorch Dataset class for ESPnet.

    Examples:
        >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'),
        ...                                  ('token_int', 'output', 'text_int')],
        ...                                 )
        >>> for uid, data in dataset:
        ...     data
        {'input': per_utt_array, 'output': per_utt_array}
    """

    def __init__(
        self,
        path_name_type_list: Collection[Tuple[any, str, str]],
        preprocess: Callable[
            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
        ] = None,
        float_dtype: str = "float32",
        fs: dict = None,
        mc: bool = False,
        int_dtype: str = "long",
        key_file: str = None,
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"'
            )

        path_name_type_list = copy.deepcopy(path_name_type_list)
        self.preprocess = preprocess

        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.key_file = key_file
        self.fs = fs
        self.mc = mc

        self.debug_info = {}
        non_iterable_list = []
        self.path_name_type_list = []

        if not isinstance(path_name_type_list[0], (Tuple, List)):
            path = path_name_type_list[0]
            name = path_name_type_list[1]
            _type = path_name_type_list[2]
            self.debug_info[name] = path, _type
            if _type not in DATA_TYPES:
                non_iterable_list.append((path, name, _type))
            else:
                self.path_name_type_list.append((path, name, _type))
        else:
            for path, name, _type in path_name_type_list:
                self.debug_info[name] = path, _type
                if _type not in DATA_TYPES:
                    non_iterable_list.append((path, name, _type))
                else:
                    self.path_name_type_list.append((path, name, _type))

        if len(non_iterable_list) != 0:
            # Some types don't support iterable mode
            self.non_iterable_dataset = ESPnetDataset(
                path_name_type_list=non_iterable_list,
                preprocess=preprocess,
                float_dtype=float_dtype,
                int_dtype=int_dtype,
            )
        else:
            self.non_iterable_dataset = None

        self.apply_utt2category = False

    def has_name(self, name) -> bool:
        return name in self.debug_info

    def names(self) -> Tuple[str, ...]:
        return tuple(self.debug_info)

    def __repr__(self):
        _mes = self.__class__.__name__
        _mes += "("
        for name, (path, _type) in self.debug_info.items():
            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
        _mes += f"\n  preprocess: {self.preprocess})"
        return _mes

    def __iter__(self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
        count = 0
        if len(self.path_name_type_list) != 0 and (self.path_name_type_list[0][2] == "bytes" or self.path_name_type_list[0][2] == "waveform"):
            linenum = len(self.path_name_type_list)
            data = {}
            for i in range(linenum):
                value = self.path_name_type_list[i][0]
                uid = 'utt_id'
                name = self.path_name_type_list[i][1]
                _type = self.path_name_type_list[i][2]
                func = DATA_TYPES[_type]
                array = func(value)
                if self.fs is not None and (name == "speech" or name == "ref_speech"):
                    audio_fs = self.fs["audio_fs"]
                    model_fs = self.fs["model_fs"]
                    if audio_fs is not None and model_fs is not None:
                        array = torch.from_numpy(array)
                        array = array.unsqueeze(0)
                        array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                               new_freq=model_fs)(array)
                        array = array.squeeze(0).numpy()

                data[name] = array

            if self.preprocess is not None:
                data = self.preprocess(uid, data)
            for name in data:
                count += 1
                value = data[name]
                if not isinstance(value, np.ndarray):
                    raise RuntimeError(
                        f'All values must be converted to np.ndarray object '
                        f'by preprocessing, but "{name}" is still {type(value)}.')
                # Cast to desired type
                if value.dtype.kind == 'f':
                    value = value.astype(self.float_dtype)
                elif value.dtype.kind == 'i':
                    value = value.astype(self.int_dtype)
                else:
                    raise NotImplementedError(
                        f'Not supported dtype: {value.dtype}')
                data[name] = value

            yield uid, data

        elif len(self.path_name_type_list) != 0 and self.path_name_type_list[0][2] == "sound" and not self.path_name_type_list[0][0].lower().endswith(".scp"):
            linenum = len(self.path_name_type_list)
            data = {}
            for i in range(linenum):
                value = self.path_name_type_list[i][0]
                uid = os.path.basename(self.path_name_type_list[i][0]).split(".")[0]
                name = self.path_name_type_list[i][1]
                _type = self.path_name_type_list[i][2]
                if _type == "sound":
                    audio_type = os.path.basename(value).lower()
                    if audio_type.rfind(".pcm") >= 0:
                        _type = "pcm"
                func = DATA_TYPES[_type]
                array = func(value)
                if self.fs is not None and (name == "speech" or name == "ref_speech"):
                    audio_fs = self.fs["audio_fs"]
                    model_fs = self.fs["model_fs"]
                    if audio_fs is not None and model_fs is not None:
                        array = torch.from_numpy(array)
                        array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                               new_freq=model_fs)(array)
                        array = array.numpy()

                if _type == "sound":
                    if self.mc:
                        data[name] = array.transpose((1, 0))
                    else:
                        data[name] = array[0]
                else:
                    data[name] = array

            if self.preprocess is not None:
                data = self.preprocess(uid, data)
            for name in data:
                count += 1
                value = data[name]
                if not isinstance(value, np.ndarray):
                    raise RuntimeError(
                        f'All values must be converted to np.ndarray object '
                        f'by preprocessing, but "{name}" is still {type(value)}.')
                # Cast to desired type
                if value.dtype.kind == 'f':
                    value = value.astype(self.float_dtype)
                elif value.dtype.kind == 'i':
                    value = value.astype(self.int_dtype)
                else:
                    raise NotImplementedError(
                        f'Not supported dtype: {value.dtype}')
                data[name] = value

            yield uid, data

        else:
            if self.key_file is not None:
                uid_iter = (
                    line.rstrip().split(maxsplit=1)[0]
                    for line in open(self.key_file, encoding="utf-8")
                )
            elif len(self.path_name_type_list) != 0:
                uid_iter = (
                    line.rstrip().split(maxsplit=1)[0]
                    for line in open(self.path_name_type_list[0][0], encoding="utf-8")
                )
            else:
                uid_iter = iter(self.non_iterable_dataset)

            files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list]

            worker_info = torch.utils.data.get_worker_info()

            linenum = 0
            for count, uid in enumerate(uid_iter, 1):
                # If num_workers>=1, split keys
                if worker_info is not None:
                    if (count - 1) % worker_info.num_workers != worker_info.id:
                        continue

                # 1. Read a line from each file
                while True:
                    keys = []
                    values = []
                    for f in files:
                        linenum += 1
                        try:
                            line = next(f)
                        except StopIteration:
                            raise RuntimeError(f"{uid} is not found in the files")
                        sps = line.rstrip().split(maxsplit=1)
                        if len(sps) != 2:
                            raise RuntimeError(
                                f"This line doesn't include a space:"
                                f" {f}:L{linenum}: {line}"
                            )
                        key, value = sps
                        keys.append(key)
                        values.append(value)

                    for k_idx, k in enumerate(keys):
                        if k != keys[0]:
                            raise RuntimeError(
                                f"Keys are mismatched. Text files (idx={k_idx}) is "
                                f"not sorted or not having same keys at L{linenum}"
                            )

                    # If the key is matched, break the loop
                    if len(keys) == 0 or keys[0] == uid:
                        break

                # 2. Load the entry from each line and create a dict
                data = {}
                # 2.a. Load data in streaming fashion
                for value, (path, name, _type) in zip(values, self.path_name_type_list):
                    if _type == "sound":
                        audio_type = os.path.basename(value).lower()
                        if audio_type.rfind(".pcm") >= 0:
                            _type = "pcm"
                    func = DATA_TYPES[_type]
                    # Load entry
                    array = func(value)
                    if self.fs is not None and name == "speech":
                        audio_fs = self.fs["audio_fs"]
                        model_fs = self.fs["model_fs"]
                        if audio_fs is not None and model_fs is not None:
                            array = torch.from_numpy(array)
                            array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                                   new_freq=model_fs)(array)
                            array = array.numpy()
                    if _type == "sound":
                        if self.mc:
                            data[name] = array.transpose((1, 0))
                        else:
                            data[name] = array[0]
                    else:
                        data[name] = array
                if self.non_iterable_dataset is not None:
                    # 2.b. Load data from non-iterable dataset
                    _, from_non_iterable = self.non_iterable_dataset[uid]
                    data.update(from_non_iterable)

                # 3. [Option] Apply preprocessing
                #    e.g. funasr_local.train.preprocessor:CommonPreprocessor
                if self.preprocess is not None:
                    data = self.preprocess(uid, data)

                # 4. Force data-precision
                for name in data:
                    value = data[name]
                    if not isinstance(value, np.ndarray):
                        raise RuntimeError(
                            f"All values must be converted to np.ndarray object "
                            f'by preprocessing, but "{name}" is still {type(value)}.'
                        )

                    # Cast to desired type
                    if value.dtype.kind == "f":
                        value = value.astype(self.float_dtype)
                    elif value.dtype.kind == "i":
                        value = value.astype(self.int_dtype)
                    else:
                        raise NotImplementedError(f"Not supported dtype: {value.dtype}")
                    data[name] = value

                yield uid, data

            if count == 0:
                raise RuntimeError("No iteration")
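Usage note: a streaming sketch of IterableESPnetDataset in its direct-file mode, the branch taken when the first entry is a plain audio path rather than an scp file. The file name and sampling rates are hypothetical placeholders.

from funasr_local.datasets.iterable_dataset import IterableESPnetDataset

dataset = IterableESPnetDataset(
    [("example/asr_example.wav", "speech", "sound")],
    fs={"audio_fs": 8000, "model_fs": 16000},  # triggers torchaudio resampling
)
for uid, data in dataset:
    # uid is the file stem, e.g. "asr_example"; data["speech"] is the first
    # channel as a float32 1-D array (mc=True would keep all channels)
    print(uid, data["speech"].shape)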
349
funasr_local/datasets/iterable_dataset_modelscope.py
Normal file
@@ -0,0 +1,349 @@

# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
"""Iterable dataset module."""
import copy
from io import StringIO
from pathlib import Path
from typing import Callable, Collection, Dict, Iterator, Tuple, Union

import kaldiio
import numpy as np
import soundfile
import torch
from funasr_local.datasets.dataset import ESPnetDataset
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types

from funasr_local.utils import wav_utils


def load_kaldi(input):
    retval = kaldiio.load_mat(input)
    if isinstance(retval, tuple):
        assert len(retval) == 2, len(retval)
        if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
            # sound scp case
            rate, array = retval
        elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
            # Extended ark format case
            array, rate = retval
        else:
            raise RuntimeError(
                f'Unexpected type: {type(retval[0])}, {type(retval[1])}')

        # Multichannel wave file
        # array: (NSample, Channel) or (NSample,)

    else:
        # Normal ark case
        assert isinstance(retval, np.ndarray), type(retval)
        array = retval
    return array


DATA_TYPES = {
    'sound':
    lambda x: soundfile.read(x)[0],
    'kaldi_ark':
    load_kaldi,
    'npy':
    np.load,
    'text_int':
    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=' '),
    'csv_int':
    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=','),
    'text_float':
    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=' '),
    'csv_float':
    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=','),
    'text':
    lambda x: x,
}


class IterableESPnetDatasetModelScope(IterableDataset):
    """Pytorch Dataset class for ESPnet.

    Examples:
        >>> dataset = IterableESPnetDatasetModelScope([('wav.scp', 'input', 'sound'),
        ...                                            ('token_int', 'output', 'text_int')],
        ...                                           )
        >>> for uid, data in dataset:
        ...     data
        {'input': per_utt_array, 'output': per_utt_array}
    """
    def __init__(self,
                 path_name_type_list: Collection[Tuple[any, str, str]],
                 preprocess: Callable[[str, Dict[str, np.ndarray]],
                                      Dict[str, np.ndarray]] = None,
                 float_dtype: str = 'float32',
                 int_dtype: str = 'long',
                 key_file: str = None,
                 sample_rate: Union[dict, int] = 16000):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"')

        self.preprocess = preprocess

        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.key_file = key_file
        self.sample_rate = sample_rate

        self.debug_info = {}
        non_iterable_list = []
        self.path_name_type_list = []

        path_list = path_name_type_list[0]
        name = path_name_type_list[1]
        _type = path_name_type_list[2]
        if name in self.debug_info:
            raise RuntimeError(f'"{name}" is duplicated for data-key')
        self.debug_info[name] = path_list, _type
        # for path, name, _type in path_name_type_list:
        for path in path_list:
            self.path_name_type_list.append((path, name, _type))

        if len(non_iterable_list) != 0:
            # Some types don't support iterable mode
            self.non_iterable_dataset = ESPnetDataset(
                path_name_type_list=non_iterable_list,
                preprocess=preprocess,
                float_dtype=float_dtype,
                int_dtype=int_dtype,
            )
        else:
            self.non_iterable_dataset = None

        self.apply_utt2category = False

    def has_name(self, name) -> bool:
        return name in self.debug_info

    def names(self) -> Tuple[str, ...]:
        return tuple(self.debug_info)

    def __repr__(self):
        _mes = self.__class__.__name__
        _mes += '('
        for name, (path, _type) in self.debug_info.items():
            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
        _mes += f'\n  preprocess: {self.preprocess})'
        return _mes

    def __iter__(
            self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
        torch.set_printoptions(profile='default')
        count = len(self.path_name_type_list)
        for idx in range(count):
            # 2. Load the entry from each line and create a dict
            data = {}
            # 2.a. Load data in streaming fashion

            # value: /home/fsc/code/MaaS/MaaS-lib-nls-asr/data/test/audios/asr_example.wav
            value = self.path_name_type_list[idx][0]['file']
            uid = self.path_name_type_list[idx][0]['key']
            # name: speech
            name = self.path_name_type_list[idx][1]
            _type = self.path_name_type_list[idx][2]
            func = DATA_TYPES[_type]
            array = func(value)

            # 2.b. audio resample
            if _type == 'sound':
                audio_sr: int = 16000
                model_sr: int = 16000
                if isinstance(self.sample_rate, int):
                    model_sr = self.sample_rate
                else:
                    if 'audio_sr' in self.sample_rate:
                        audio_sr = self.sample_rate['audio_sr']
                    if 'model_sr' in self.sample_rate:
                        model_sr = self.sample_rate['model_sr']
                array = wav_utils.torch_resample(array, audio_sr, model_sr)

            # array: [ 1.25122070e-03 ... ]
            data[name] = array

            # 3. [Option] Apply preprocessing
            #    e.g. espnet2.train.preprocessor:CommonPreprocessor
            if self.preprocess is not None:
                data = self.preprocess(uid, data)
            # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}

            # 4. Force data-precision
            for name in data:
                # value is np.ndarray data
                value = data[name]
                if not isinstance(value, np.ndarray):
                    raise RuntimeError(
                        f'All values must be converted to np.ndarray object '
                        f'by preprocessing, but "{name}" is still {type(value)}.'
                    )

                # Cast to desired type
                if value.dtype.kind == 'f':
                    value = value.astype(self.float_dtype)
                elif value.dtype.kind == 'i':
                    value = value.astype(self.int_dtype)
                else:
                    raise NotImplementedError(
                        f'Not supported dtype: {value.dtype}')
                data[name] = value

            yield uid, data

        if count == 0:
            raise RuntimeError('No iteration')


class IterableESPnetBytesModelScope(IterableDataset):
    """Pytorch audio bytes class for ESPnet.

    Examples:
        >>> dataset = IterableESPnetBytesModelScope([('audio bytes', 'input', 'sound'),
        ...                                          ('token_int', 'output', 'text_int')],
        ...                                         )
        >>> for uid, data in dataset:
        ...     data
        {'input': per_utt_array, 'output': per_utt_array}
    """
    def __init__(self,
                 path_name_type_list: Collection[Tuple[any, str, str]],
                 preprocess: Callable[[str, Dict[str, np.ndarray]],
                                      Dict[str, np.ndarray]] = None,
                 float_dtype: str = 'float32',
                 int_dtype: str = 'long',
                 key_file: str = None,
                 sample_rate: Union[dict, int] = 16000):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"')

        self.preprocess = preprocess

        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.key_file = key_file
        self.sample_rate = sample_rate

        self.debug_info = {}
        non_iterable_list = []
        self.path_name_type_list = []

        audio_data = path_name_type_list[0]
        name = path_name_type_list[1]
        _type = path_name_type_list[2]
        if name in self.debug_info:
            raise RuntimeError(f'"{name}" is duplicated for data-key')
        self.debug_info[name] = audio_data, _type
        self.path_name_type_list.append((audio_data, name, _type))

        if len(non_iterable_list) != 0:
            # Some types don't support iterable mode
            self.non_iterable_dataset = ESPnetDataset(
                path_name_type_list=non_iterable_list,
                preprocess=preprocess,
                float_dtype=float_dtype,
                int_dtype=int_dtype,
            )
        else:
            self.non_iterable_dataset = None

        self.apply_utt2category = False

        if float_dtype == 'float32':
            self.np_dtype = np.float32

    def has_name(self, name) -> bool:
        return name in self.debug_info

    def names(self) -> Tuple[str, ...]:
        return tuple(self.debug_info)

    def __repr__(self):
        _mes = self.__class__.__name__
        _mes += '('
        for name, (path, _type) in self.debug_info.items():
            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
        _mes += f'\n  preprocess: {self.preprocess})'
        return _mes

    def __iter__(
            self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:

        torch.set_printoptions(profile='default')
        # 2. Load the entry from each line and create a dict
        data = {}
        # 2.a. Load data in streaming fashion

        value = self.path_name_type_list[0][0]
        uid = 'pcm_data'
        # name: speech
        name = self.path_name_type_list[0][1]
        _type = self.path_name_type_list[0][2]
        func = DATA_TYPES[_type]
        # array: [ 1.25122070e-03 ... ]
        # data[name] = np.frombuffer(value, dtype=self.np_dtype)

        # 2.b. byte(PCM16) to float32
        middle_data = np.frombuffer(value, dtype=np.int16)
        middle_data = np.asarray(middle_data)
        if middle_data.dtype.kind not in 'iu':
            raise TypeError("'middle_data' must be an array of integers")
        dtype = np.dtype('float32')
        if dtype.kind != 'f':
            raise TypeError("'dtype' must be a floating point type")

        i = np.iinfo(middle_data.dtype)
        abs_max = 2**(i.bits - 1)
        offset = i.min + abs_max
        array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max,
                              dtype=self.np_dtype)

        # 2.c. audio resample
        if _type == 'sound':
            audio_sr: int = 16000
            model_sr: int = 16000
            if isinstance(self.sample_rate, int):
                model_sr = self.sample_rate
            else:
                if 'audio_sr' in self.sample_rate:
                    audio_sr = self.sample_rate['audio_sr']
                if 'model_sr' in self.sample_rate:
                    model_sr = self.sample_rate['model_sr']
            array = wav_utils.torch_resample(array, audio_sr, model_sr)

        data[name] = array

        # 3. [Option] Apply preprocessing
        #    e.g. espnet2.train.preprocessor:CommonPreprocessor
        if self.preprocess is not None:
            data = self.preprocess(uid, data)
        # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}

        # 4. Force data-precision
        for name in data:
            # value is np.ndarray data
            value = data[name]
            if not isinstance(value, np.ndarray):
                raise RuntimeError(
                    f'All values must be converted to np.ndarray object '
                    f'by preprocessing, but "{name}" is still {type(value)}.')

            # Cast to desired type
            if value.dtype.kind == 'f':
                value = value.astype(self.float_dtype)
            elif value.dtype.kind == 'i':
                value = value.astype(self.int_dtype)
            else:
                raise NotImplementedError(
                    f'Not supported dtype: {value.dtype}')
            data[name] = value

        yield uid, data
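Usage note: a sketch of the ModelScope variant, matching the dict-shaped entries it reads ('key' and 'file' per item) and the sample_rate handling above. The key and file path are placeholders.

from funasr_local.datasets.iterable_dataset_modelscope import (
    IterableESPnetDatasetModelScope,
)

dataset = IterableESPnetDatasetModelScope(
    ([{"key": "asr_example", "file": "audios/asr_example.wav"}],
     "speech", "sound"),
    sample_rate={"audio_sr": 16000, "model_sr": 16000},
)
for uid, data in dataset:
    print(uid, data["speech"].dtype)  # "asr_example" float32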
0
funasr_local/datasets/large_datasets/__init__.py
Normal file
96
funasr_local/datasets/large_datasets/build_dataloader.py
Normal file
@@ -0,0 +1,96 @@

import logging
from pathlib import Path
from typing import Iterable
from typing import List
from typing import Union

import sentencepiece as spm
from torch.utils.data import DataLoader
from typeguard import check_argument_types

from funasr_local.datasets.large_datasets.dataset import Dataset
from funasr_local.iterators.abs_iter_factory import AbsIterFactory
from funasr_local.text.abs_tokenizer import AbsTokenizer


def read_symbol_table(symbol_table_file):
    if isinstance(symbol_table_file, str):
        symbol_table = {}
        with open(symbol_table_file, "r", encoding="utf8") as fin:
            for i, line in enumerate(fin):
                char = line.strip()
                symbol_table[char] = i
    else:
        assert isinstance(symbol_table_file, list)
        symbol_table = {}
        for i, char in enumerate(symbol_table_file):
            symbol_table[char] = i
    return symbol_table


def load_seg_dict(seg_dict_file):
    seg_dict = {}
    assert isinstance(seg_dict_file, str)
    with open(seg_dict_file, "r", encoding="utf8") as f:
        lines = f.readlines()
        for line in lines:
            s = line.strip().split()
            key = s[0]
            value = s[1:]
            seg_dict[key] = " ".join(value)
    return seg_dict


class SentencepiecesTokenizer(AbsTokenizer):
    def __init__(self, model: Union[Path, str]):
        assert check_argument_types()
        self.model = str(model)
        self.sp = None

    def __repr__(self):
        return f'{self.__class__.__name__}(model="{self.model}")'

    def _build_sentence_piece_processor(self):
        if self.sp is None:
            self.sp = spm.SentencePieceProcessor()
            self.sp.load(self.model)

    def text2tokens(self, line: str) -> List[str]:
        self._build_sentence_piece_processor()
        return self.sp.EncodeAsPieces(line)

    def tokens2text(self, tokens: Iterable[str]) -> str:
        self._build_sentence_piece_processor()
        return self.sp.DecodePieces(list(tokens))


class ArkDataLoader(AbsIterFactory):
    def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None,
                 bpemodel_file=None, mode="train"):
        symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
        if seg_dict_file is not None:
            seg_dict = load_seg_dict(seg_dict_file)
        else:
            seg_dict = None
        if punc_dict_file is not None:
            punc_dict = read_symbol_table(punc_dict_file)
        else:
            punc_dict = None
        self.dataset_conf = dataset_conf
        self.frontend_conf = frontend_conf
        logging.info("dataloader config: {}".format(self.dataset_conf))
        batch_mode = self.dataset_conf.get("batch_mode", "padding")
        if bpemodel_file is not None:
            bpe_tokenizer = SentencepiecesTokenizer(bpemodel_file)
        else:
            bpe_tokenizer = None
        self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
                               self.dataset_conf, self.frontend_conf, mode=mode, batch_mode=batch_mode)

    def build_iter(self, epoch, shuffle=True):
        self.dataset.set_epoch(epoch)
        data_loader = DataLoader(self.dataset,
                                 batch_size=None,
                                 pin_memory=True,
                                 num_workers=self.dataset_conf.get("num_workers", 8))
        return data_loader
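Usage note: a quick sketch of read_symbol_table, which accepts either a vocab file path (one token per line) or an in-memory token list. The tokens below are made up.

from funasr_local.datasets.large_datasets.build_dataloader import read_symbol_table

symbol_table = read_symbol_table(["<blank>", "<s>", "</s>", "a", "b"])
# {"<blank>": 0, "<s>": 1, "</s>": 2, "a": 3, "b": 4}
# With a str argument, the same mapping is built from the file's lines.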
213
funasr_local/datasets/large_datasets/datapipes/batch.py
Normal file
@@ -0,0 +1,213 @@

import random

from itertools import count
from functools import partial
from torch.utils.data import IterableDataset
from funasr_local.datasets.large_datasets.datapipes.map import MapperIterDataPipe

tiebreaker = count()


def _default_len_fn(token):
    return len(token), next(tiebreaker)


def _token_len_fn(token, len_fn):
    return len_fn(token), next(tiebreaker), token


class MaxTokenBucketizerIterDataPipe(IterableDataset):

    def __init__(
        self,
        datapipe,
        batch_size=8000,
        len_fn=_default_len_fn,
        buffer_size=10240,
        sort_size=500,
        batch_mode="padding",
    ):
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
        assert sort_size > 0, "Sort size is required to be larger than 0!"

        datapipe = MapperIterDataPipe(datapipe, fn=partial(_token_len_fn, len_fn=len_fn))
        self.datapipe = datapipe
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.sort_size = sort_size
        self.batch_mode = batch_mode

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        buffer = []
        batch = []
        bucket = []
        max_lengths = 0
        min_lengths = 999999
        batch_lengths = 0

        if self.batch_mode == "clipping":
            assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 0"
            for d in self.datapipe:
                if d[0] > self.batch_size:
                    continue
                buffer.append(d)
                if len(buffer) == self.buffer_size:
                    random.shuffle(buffer)
                    for sample in buffer:
                        bucket.append(sample)
                        if len(bucket) == self.sort_size:
                            bucket.sort()
                            for x in bucket:
                                length, _, token = x
                                if length < min_lengths:
                                    min_lengths = length
                                batch_lengths = min_lengths * (len(batch) + 1)
                                if batch_lengths > self.batch_size:
                                    yield batch
                                    batch = []
                                    min_lengths = length
                                batch.append(token)
                            bucket = []
                    buffer = []

            if buffer:
                random.shuffle(buffer)
                for sample in buffer:
                    bucket.append(sample)
                    if len(bucket) == self.sort_size:
                        bucket.sort()
                        for x in bucket:
                            length, _, token = x
                            if length < min_lengths:
                                min_lengths = length
                            batch_lengths = min_lengths * (len(batch) + 1)
                            if batch_lengths > self.batch_size:
                                yield batch
                                batch = []
                                min_lengths = length
                            batch.append(token)
                        bucket = []
                buffer = []

            if bucket:
                bucket.sort()
                for x in bucket:
                    length, _, token = x
                    if length < min_lengths:
                        min_lengths = length
                    batch_lengths = min_lengths * (len(batch) + 1)
                    if batch_lengths > self.batch_size:
                        yield batch
                        batch = []
                        min_lengths = length
                    batch.append(token)
                bucket = []

            if batch:
                yield batch

        else:
            if self.buffer_size == -1:
                for d in self.datapipe:
                    if d[0] > self.batch_size:
                        continue
                    buffer.append(d)
                buffer.sort()
                for sample in buffer:
                    length, _, token = sample
                    if length > max_lengths:
                        max_lengths = length
                    batch_lengths = max_lengths * (len(batch) + 1)
                    if batch_lengths > self.batch_size:
                        bucket.append(batch)
                        batch = []
                        max_lengths = length
                    batch.append(token)
                random.shuffle(bucket)
                if bucket:
                    for batch_sample in bucket:
                        yield batch_sample
                if batch:
                    yield batch

            elif self.buffer_size == 0:
                for d in self.datapipe:
                    if d[0] > self.batch_size:
                        continue
                    length, _, token = d
                    if length > self.batch_size:
                        continue
                    if length > max_lengths:
                        max_lengths = length
                    batch_lengths = max_lengths * (len(batch) + 1)
                    if batch_lengths > self.batch_size:
                        yield batch
                        batch = []
                        max_lengths = length
                    batch.append(token)
                if batch:
                    yield batch

            else:
                for d in self.datapipe:
                    if d[0] > self.batch_size:
                        continue
                    buffer.append(d)
                    if len(buffer) == self.buffer_size:
                        random.shuffle(buffer)
                        for sample in buffer:
                            bucket.append(sample)
                            if len(bucket) == self.sort_size:
                                bucket.sort()
                                for x in bucket:
                                    length, _, token = x
                                    if length > max_lengths:
                                        max_lengths = length
                                    batch_lengths = max_lengths * (len(batch) + 1)
                                    if batch_lengths > self.batch_size:
                                        yield batch
                                        batch = []
                                        max_lengths = length
                                    batch.append(token)
                                bucket = []
                        buffer = []

                if buffer:
                    random.shuffle(buffer)
                    for sample in buffer:
                        bucket.append(sample)
                        if len(bucket) == self.sort_size:
                            bucket.sort()
                            for x in bucket:
                                length, _, token = x
                                if length > max_lengths:
                                    max_lengths = length
                                batch_lengths = max_lengths * (len(batch) + 1)
                                if batch_lengths > self.batch_size:
                                    yield batch
                                    batch = []
                                    max_lengths = length
                                batch.append(token)
                            bucket = []
                    buffer = []

                if bucket:
                    bucket.sort()
                    for x in bucket:
                        length, _, token = x
                        if length > max_lengths:
                            max_lengths = length
                        batch_lengths = max_lengths * (len(batch) + 1)
                        if batch_lengths > self.batch_size:
                            yield batch
                            batch = []
                            max_lengths = length
                        batch.append(token)
                    bucket = []

                if batch:
                    yield batch
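Usage note: a toy sketch of the token-budget batching above in "padding" mode with buffer_size=0 (the unbuffered path). len_fn=len is passed explicitly because the module-default _default_len_fn returns a (length, tiebreaker) tuple, which works as a sort key but not as the numeric budget compared against batch_size.

from funasr_local.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe

seqs = [[0] * n for n in (3, 4, 5, 9, 2)]  # token sequences of length 3, 4, 5, 9, 2
batcher = MaxTokenBucketizerIterDataPipe(seqs, batch_size=10, len_fn=len, buffer_size=0)
print([[len(x) for x in batch] for batch in batcher])
# [[3, 4], [5], [9], [2]]: a batch is flushed as soon as
# max_length * (len(batch) + 1) would exceed the 10-token budget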
24
funasr_local/datasets/large_datasets/datapipes/filter.py
Normal file
@@ -0,0 +1,24 @@

from torch.utils.data import IterableDataset


def default_fn(data):
    return data


class FilterIterDataPipe(IterableDataset):

    def __init__(self,
                 datapipe,
                 fn=default_fn):
        self.datapipe = datapipe
        self.fn = fn

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        assert callable(self.fn)
        for data in self.datapipe:
            if self.fn(data):
                yield data
22
funasr_local/datasets/large_datasets/datapipes/map.py
Normal file
@@ -0,0 +1,22 @@

from torch.utils.data import IterableDataset


def default_fn(data):
    return data


class MapperIterDataPipe(IterableDataset):

    def __init__(self,
                 datapipe,
                 fn=default_fn):
        self.datapipe = datapipe
        self.fn = fn

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        assert callable(self.fn)
        for data in self.datapipe:
            yield self.fn(data)
212
funasr_local/datasets/large_datasets/dataset.py
Normal file
@@ -0,0 +1,212 @@
import os
import random
import numpy
from functools import partial

import torch
import torchaudio
import torch.distributed as dist
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset

from funasr_local.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
from funasr_local.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
from funasr_local.datasets.large_datasets.datapipes.map import MapperIterDataPipe
from funasr_local.datasets.large_datasets.utils.filter import filter
from funasr_local.datasets.large_datasets.utils.padding import padding
from funasr_local.datasets.large_datasets.utils.clipping import clipping
from funasr_local.datasets.large_datasets.utils.tokenize import tokenize


def read_lists(list_file):
    lists = []
    with open(list_file, 'r', encoding='utf8') as fin:
        for line in fin:
            parts = line.strip()
            lists.append(parts)
    return lists


class AudioDataset(IterableDataset):
    def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train"):
        self.scp_lists = scp_lists
        self.data_names = data_names
        self.data_types = data_types
        self.frontend_conf = frontend_conf
        self.shuffle = shuffle
        self.mode = mode
        self.epoch = -1
        self.rank = 0
        self.world_size = 1
        self.worker_id = 0
        self.num_workers = 1

    def set_epoch(self, epoch):
        self.epoch = epoch

    def get_rank_data_list(self, data_index):
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1

        if self.mode == "train":
            if self.shuffle:
                random.seed(self.epoch)
                random.shuffle(data_index)
            return data_index[self.rank::self.world_size]

        return data_index

    def get_worker_data_list(self, rank_data_index):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers

        return rank_data_index[self.worker_id::self.num_workers]

    def close_reader(self, reader_list):
        for reader in reader_list:
            reader.close()

    def __iter__(self):
        data_index = list(range(len(self.scp_lists)))
        rank_data_index = self.get_rank_data_list(data_index)
        worker_data_index = self.get_worker_data_list(rank_data_index)

        for index in worker_data_index:
            data = dict(scp=self.scp_lists[index])

            assert 'scp' in data
            scp = data['scp']
            data_file_list = scp.strip().split()
            data_name_list = self.data_names.split(",")
            data_type_list = self.data_types.split(",")

            for file in data_file_list:
                assert os.path.exists(file), "{} not exists".format(file)

            assert len(data_file_list) == len(data_name_list) == len(data_type_list), \
                "The item number of data, data_names, data_types must be the same"

            reader_list = []
            for data_file, data_type in zip(data_file_list, data_type_list):
                if data_type == "kaldi_ark":
                    ark_reader = ReadHelper('ark:{}'.format(data_file))
                    reader_list.append(ark_reader)
                elif data_type == "text" or data_type == "sound":
                    text_reader = open(data_file, "r")
                    reader_list.append(text_reader)
                elif data_type == "none":
                    continue
                else:
                    raise TypeError("Data type {} is not supported".format(data_type))

            for items in zip(*reader_list):
                sample_dict = {}
                for item, (data_name, data_type) in zip(items, zip(data_name_list, data_type_list)):
                    if data_type == "kaldi_ark":
                        key, mat = item
                        sample_dict[data_name] = mat
                        if data_name == "speech":
                            sample_dict["key"] = key
                    elif data_type == "sound":
                        key, path = item.strip().split()
                        waveform, sampling_rate = torchaudio.load(path)
                        if self.frontend_conf is not None:
                            if sampling_rate != self.frontend_conf["fs"]:
                                waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
                                                                          new_freq=self.frontend_conf["fs"])(waveform)
                                sampling_rate = self.frontend_conf["fs"]
                        waveform = waveform.numpy()
                        mat = waveform[0]
                        sample_dict[data_name] = mat
                        sample_dict["sampling_rate"] = sampling_rate
                        if data_name == "speech":
                            sample_dict["key"] = key
                    else:
                        text = item
                        segs = text.strip().split()
                        sample_dict[data_name] = segs[1:]
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                yield sample_dict

            self.close_reader(reader_list)


def len_fn_example(data):
    return 1


def len_fn_token(data):
    assert "speech" in data
    if "sampling_rate" in data:
        return (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
    else:
        return data["speech"].shape[0]


def Dataset(data_list_file,
            dict,
            seg_dict,
            punc_dict,
            bpe_tokenizer,
            conf,
            frontend_conf,
            mode="train",
            batch_mode="padding"):
    scp_lists = read_lists(data_list_file)
    shuffle = conf.get('shuffle', True)
    data_names = conf.get("data_names", "speech,text")
    data_types = conf.get("data_types", "kaldi_ark,text")
    dataset = AudioDataset(scp_lists, data_names, data_types, frontend_conf=frontend_conf, shuffle=shuffle, mode=mode)

    filter_conf = conf.get('filter_conf', {})
    filter_fn = partial(filter, **filter_conf)
    dataset = FilterIterDataPipe(dataset, fn=filter_fn)

    if "text" in data_names:
        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer}
        tokenize_fn = partial(tokenize, **vocab)
        dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)

    if shuffle:
        buffer_conf = conf.get('shuffle_conf', {})
        buffer_size = buffer_conf['shuffle_size']
        sort_size = buffer_conf['sort_size']
    else:
        buffer_size = 0
        sort_size = 1

    batch_conf = conf.get('batch_conf', {})
    batch_size = batch_conf['batch_size']
    batch_type = batch_conf['batch_type']

    assert batch_type in ["example", "token"]
    if batch_type == 'example':
        len_fn = len_fn_example
    else:
        len_fn = len_fn_token

    dataset = MaxTokenBucketizerIterDataPipe(dataset,
                                             batch_size=batch_size,
                                             len_fn=len_fn,
                                             buffer_size=buffer_size,
                                             sort_size=sort_size,
                                             batch_mode=batch_mode)

    int_pad_value = conf.get("int_pad_value", -1)
    float_pad_value = conf.get("float_pad_value", 0.0)
    padding_conf = {"int_pad_value": int_pad_value, "float_pad_value": float_pad_value}
    padding_fn = partial(padding, **padding_conf)
    dataset = MapperIterDataPipe(dataset, fn=padding_fn if batch_mode == "padding" else clipping)

    return dataset
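A hedged sketch of the conf dict Dataset reads; the keys mirror the conf.get(...) lookups above, but the values are illustrative:

conf = {
    "data_names": "speech,text",
    "data_types": "kaldi_ark,text",
    "shuffle": True,
    "shuffle_conf": {"shuffle_size": 2048, "sort_size": 512},
    "batch_conf": {"batch_type": "token", "batch_size": 6000},
    "filter_conf": {"speech_length_min": 100, "speech_length_max": 15000},
    "int_pad_value": -1,
    "float_pad_value": 0.0,
}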
40
funasr_local/datasets/large_datasets/utils/clipping.py
Normal file
@@ -0,0 +1,40 @@
import numpy as np
import torch

from funasr_local.datasets.collate_fn import crop_to_max_size


def clipping(data):
    assert isinstance(data, list)
    assert "key" in data[0]

    keys = [x["key"] for x in data]

    batch = {}
    data_names = data[0].keys()
    for data_name in data_names:
        if data_name == "key":
            continue
        else:
            if data[0][data_name].dtype.kind == "i":
                tensor_type = torch.int64
            else:
                tensor_type = torch.float32

            tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
            tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)

            length_clip = min(tensor_lengths)
            tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
            for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
                diff = length - length_clip
                assert diff >= 0
                if diff == 0:
                    tensor_clip[i] = tensor
                else:
                    tensor_clip[i] = crop_to_max_size(tensor, length_clip)

            batch[data_name] = tensor_clip
            batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long)

    return keys, batch
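A quick hedged check of the random-crop helper reused above (imports assume the repo layout):

import numpy as np
from funasr_local.datasets.collate_fn import crop_to_max_size

feat = np.arange(10)                # length 10
window = crop_to_max_size(feat, 6)  # random contiguous window of length 6
assert len(window) == 6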
26
funasr_local/datasets/large_datasets/utils/filter.py
Normal file
@@ -0,0 +1,26 @@
#!/usr/bin/env python


def filter(data,
           speech_length_min=100,
           speech_length_max=15000,
           token_length_min=0,
           token_length_max=200):
    assert "speech" in data or "text" in data

    if "speech" in data and "text" in data:
        if "sampling_rate" in data:
            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
        else:
            speech_length = data["speech"].shape[0]
        num_tokens = len(data['text'])
        return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
    elif "speech" in data:
        if "sampling_rate" in data:
            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
        else:
            speech_length = data["speech"].shape[0]
        return speech_length_min < speech_length < speech_length_max
    else:
        num_tokens = len(data['text'])
        return token_length_min < num_tokens < token_length_max
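Hedged usage note: the speech length compares in milliseconds when sampling_rate is present (raw audio) and in frames otherwise, so the defaults keep utterances roughly between 0.1 s and 15 s. For example:

import numpy as np

sample = {"speech": np.zeros(16000), "sampling_rate": 16000, "text": list("hello")}
assert filter(sample)  # 1000 ms of audio and 5 tokens sit inside the default bounds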
30
funasr_local/datasets/large_datasets/utils/low_frame_rate.py
Normal file
@@ -0,0 +1,30 @@
import numpy as np


def build_LFR_features(data, m, n):
    """
    Actually, this implements stacking frames and skipping frames.
    if m = 1 and n = 1, just return the original features.
    if m = 1 and n > 1, it works like skipping.
    if m > 1 and n = 1, it works like stacking but only supports right frames.
    if m > 1 and n > 1, it works like LFR.

    Args:
        inputs_batch: inputs is T x D np.ndarray
        m: number of frames to stack
        n: number of frames to skip
    """

    LFR_inputs = []
    T = data.shape[0]
    T_lfr = int(np.ceil(T / n))
    for i in range(T_lfr):
        if m <= T - i * n:
            LFR_inputs.append(np.hstack(data[i*n:i*n+m]))
        else:
            num_padding = m - (T - i * n)
            frame = np.hstack(data[i*n:])
            for _ in range(num_padding):
                frame = np.hstack((frame, data[-1]))
            LFR_inputs.append(frame)
    return np.vstack(LFR_inputs)
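A worked example of the loop above under the common low-frame-rate setting m=4, n=3 (the shapes follow directly from the code):

import numpy as np

feats = np.arange(10 * 2).reshape(10, 2)   # T=10 frames, D=2 dims
lfr = build_LFR_features(feats, m=4, n=3)  # T_lfr = ceil(10/3) = 4 output frames
assert lfr.shape == (4, 8)                 # each output frame stacks m*D = 8 values
# The last output frame runs past T and pads by repeating feats[-1].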
34
funasr_local/datasets/large_datasets/utils/padding.py
Normal file
@@ -0,0 +1,34 @@
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence


def padding(data, float_pad_value=0.0, int_pad_value=-1):
    assert isinstance(data, list)
    assert "key" in data[0]
    assert "speech" in data[0] or "text" in data[0]

    keys = [x["key"] for x in data]

    batch = {}
    data_names = data[0].keys()
    for data_name in data_names:
        if data_name == "key" or data_name == "sampling_rate":
            continue
        else:
            if data[0][data_name].dtype.kind == "i":
                pad_value = int_pad_value
                tensor_type = torch.int64
            else:
                pad_value = float_pad_value
                tensor_type = torch.float32

            tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
            tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
            tensor_pad = pad_sequence(tensor_list,
                                      batch_first=True,
                                      padding_value=pad_value)
            batch[data_name] = tensor_pad
            batch[data_name + "_lengths"] = tensor_lengths

    return keys, batch
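A brief hedged example of the padded-batch layout this produces (toy arrays):

import numpy as np

data = [
    {"key": "utt1", "text": np.array([3, 4])},
    {"key": "utt2", "text": np.array([5, 6, 7])},
]
keys, batch = padding(data)
# batch["text"] -> [[3, 4, -1], [5, 6, 7]]; batch["text_lengths"] -> [2, 3]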
81
funasr_local/datasets/large_datasets/utils/tokenize.py
Normal file
@@ -0,0 +1,81 @@
#!/usr/bin/env python
import re
import numpy as np


def forward_segment(text, seg_dict):
    word_list = []
    i = 0
    while i < len(text):
        longest_word = text[i]
        for j in range(i + 1, len(text) + 1):
            word = text[i:j]
            if word in seg_dict:
                if len(word) > len(longest_word):
                    longest_word = word
        word_list.append(longest_word)
        i += len(longest_word)
    return word_list


def seg_tokenize(txt, seg_dict):
    pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
    out_txt = ""
    for word in txt:
        word = word.lower()
        if word in seg_dict:
            out_txt += seg_dict[word] + " "
        else:
            if pattern.match(word):
                for char in word:
                    if char in seg_dict:
                        out_txt += seg_dict[char] + " "
                    else:
                        out_txt += "<unk>" + " "
            else:
                out_txt += "<unk>" + " "
    return out_txt.strip().split()


def tokenize(data,
             vocab=None,
             seg_dict=None,
             punc_dict=None,
             bpe_tokenizer=None):
    assert "text" in data
    assert isinstance(vocab, dict)
    text = data["text"]
    token = []
    vad = -2

    if bpe_tokenizer is not None:
        text = bpe_tokenizer.text2tokens("".join(text))

    if seg_dict is not None:
        assert isinstance(seg_dict, dict)
        text = seg_tokenize(text, seg_dict)

    length = len(text)
    for i in range(length):
        x = text[i]
        if i == length - 1 and "punc" in data and x.startswith("vad:"):
            vad = x[4:]
            if len(vad) == 0:
                vad = -1
            else:
                vad = int(vad)
        elif x in vocab:
            token.append(vocab[x])
        else:
            token.append(vocab['<unk>'])

    if "punc" in data and punc_dict is not None:
        punc_token = []
        for punc in data["punc"]:
            if punc in punc_dict:
                punc_token.append(punc_dict[punc])
            else:
                punc_token.append(punc_dict["_"])
        data["punc"] = np.array(punc_token)

    data["text"] = np.array(token)
    if vad != -2:
        data["vad_indexes"] = np.array([vad], dtype=np.int64)
    return data
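For intuition, a tiny hedged example of the greedy longest-match rule in forward_segment (toy dictionary):

seg = {"ab": "ab", "abc": "abc", "d": "d"}
print(forward_segment("abcd", seg))  # ['abc', 'd'] -- the longest match wins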
33
funasr_local/datasets/ms_dataset.py
Normal file
@@ -0,0 +1,33 @@
import os


class MsDataset(object):
    @classmethod
    def load_core(cls, data_dir, data_set):
        wav_file = os.path.join(data_dir, data_set, "wav.scp")
        text_file = os.path.join(data_dir, data_set, "text")
        with open(wav_file) as f:
            wav_lines = f.readlines()
        with open(text_file) as f:
            text_lines = f.readlines()
        data_list = []
        for wav_line, text_line in zip(wav_lines, text_lines):
            item = {}
            item["Audio:FILE"] = wav_line.strip().split()[-1]
            item["Text:LABEL"] = " ".join(text_line.strip().split()[1:])
            data_list.append(item)
        return data_list

    @classmethod
    def load(cls, dataset_name, namespace="speech_asr", train_set="train", dev_set="validation"):
        if os.path.exists(dataset_name):
            data_dir = dataset_name
            ds_dict = {}
            ds_dict["train"] = cls.load_core(data_dir, train_set)
            ds_dict["validation"] = cls.load_core(data_dir, dev_set)
            ds_dict["raw_data_dir"] = data_dir
            return ds_dict
        else:
            from modelscope.msdatasets import MsDataset
            ds_dict = MsDataset.load(dataset_name=dataset_name, namespace=namespace)
            return ds_dict
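The on-disk layout load_core expects, inferred from the parsing above (hedged; paths illustrative): each split directory holds a Kaldi-style wav.scp and text file keyed by utterance id.

# data_dir/train/wav.scp : "utt001 /path/to/utt001.wav"
# data_dir/train/text    : "utt001 hello world"
ds = MsDataset.load("/path/to/data_dir")
# -> {"train": [...], "validation": [...], "raw_data_dir": "/path/to/data_dir"}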
824
funasr_local/datasets/preprocessor.py
Normal file
@@ -0,0 +1,824 @@
import re
from abc import ABC
from abc import abstractmethod
from pathlib import Path
from typing import Collection
from typing import Dict
from typing import Iterable
from typing import List
from typing import Union

import numpy as np
import scipy.signal
import soundfile
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr_local.text.build_tokenizer import build_tokenizer
from funasr_local.text.cleaner import TextCleaner
from funasr_local.text.token_id_converter import TokenIDConverter


class AbsPreprocessor(ABC):
    def __init__(self, train: bool):
        self.train = train

    @abstractmethod
    def __call__(
        self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        raise NotImplementedError


def forward_segment(text, dic):
    word_list = []
    i = 0
    while i < len(text):
        longest_word = text[i]
        for j in range(i + 1, len(text) + 1):
            word = text[i:j]
            if word in dic:
                if len(word) > len(longest_word):
                    longest_word = word
        word_list.append(longest_word)
        i += len(longest_word)
    return word_list


def seg_tokenize(txt, seg_dict):
    pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
    out_txt = ""
    for word in txt:
        word = word.lower()
        if word in seg_dict:
            out_txt += seg_dict[word] + " "
        else:
            if pattern.match(word):
                for char in word:
                    if char in seg_dict:
                        out_txt += seg_dict[char] + " "
                    else:
                        out_txt += "<unk>" + " "
            else:
                out_txt += "<unk>" + " "
    return out_txt.strip().split()


def seg_tokenize_wo_pattern(txt, seg_dict):
    out_txt = ""
    for word in txt:
        if word in seg_dict:
            out_txt += seg_dict[word] + " "
        else:
            out_txt += "<unk>" + " "
    return out_txt.strip().split()


def framing(
    x,
    frame_length: int = 512,
    frame_shift: int = 256,
    centered: bool = True,
    padded: bool = True,
):
    if x.size == 0:
        raise ValueError("Input array size is zero")
    if frame_length < 1:
        raise ValueError("frame_length must be a positive integer")
    if frame_length > x.shape[-1]:
        raise ValueError("frame_length is greater than input length")
    if 0 >= frame_shift:
        raise ValueError("frame_shift must be greater than 0")

    if centered:
        pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [
            (frame_length // 2, frame_length // 2)
        ]
        x = np.pad(x, pad_shape, mode="constant", constant_values=0)

    if padded:
        # Pad to an integer number of windowed segments,
        # i.e. make x.shape[-1] = frame_length + (nseg - 1) * nstep
        # with integer nseg
        nadd = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length
        pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [(0, nadd)]
        x = np.pad(x, pad_shape, mode="constant", constant_values=0)

    # Create a strided array of data segments
    if frame_length == 1 and frame_length == frame_shift:
        result = x[..., None]
    else:
        shape = x.shape[:-1] + (
            (x.shape[-1] - frame_length) // frame_shift + 1,
            frame_length,
        )
        strides = x.strides[:-1] + (frame_shift * x.strides[-1], x.strides[-1])
        result = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
    return result


def detect_non_silence(
    x: np.ndarray,
    threshold: float = 0.01,
    frame_length: int = 1024,
    frame_shift: int = 512,
    window: str = "boxcar",
) -> np.ndarray:
    """Power based voice activity detection.

    Args:
        x: (Channel, Time)
    >>> x = np.random.randn(1000)
    >>> detect = detect_non_silence(x)
    >>> assert x.shape == detect.shape
    >>> assert detect.dtype == np.bool_
    """
    if x.shape[-1] < frame_length:
        return np.full(x.shape, fill_value=True, dtype=bool)

    if x.dtype.kind == "i":
        x = x.astype(np.float64)
    # framed_w: (C, T, F)
    framed_w = framing(
        x,
        frame_length=frame_length,
        frame_shift=frame_shift,
        centered=False,
        padded=True,
    )
    framed_w *= scipy.signal.get_window(window, frame_length).astype(framed_w.dtype)
    # power: (C, T)
    power = (framed_w ** 2).mean(axis=-1)
    # mean_power: (C, 1)
    mean_power = np.mean(power, axis=-1, keepdims=True)
    if np.all(mean_power == 0):
        return np.full(x.shape, fill_value=True, dtype=bool)
    # detect_frames: (C, T)
    detect_frames = power / mean_power > threshold
    # detects: (C, T, F)
    detects = np.broadcast_to(
        detect_frames[..., None], detect_frames.shape + (frame_shift,)
    )
    # detects: (C, TF)
    detects = detects.reshape(*detect_frames.shape[:-1], -1)
    # detects: (C, TF)
    return np.pad(
        detects,
        [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])],
        mode="edge",
    )


class CommonPreprocessor(AbsPreprocessor):
    def __init__(
        self,
        train: bool,
        token_type: str = None,
        token_list: Union[Path, str, Iterable[str]] = None,
        bpemodel: Union[Path, str, Iterable[str]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: str = "text",
        split_with_space: bool = False,
        seg_dict_file: str = None,
    ):
        super().__init__(train)
        self.train = train
        self.speech_name = speech_name
        self.text_name = text_name
        self.speech_volume_normalize = speech_volume_normalize
        self.rir_apply_prob = rir_apply_prob
        self.noise_apply_prob = noise_apply_prob
        self.split_with_space = split_with_space
        self.seg_dict = None
        if seg_dict_file is not None:
            self.seg_dict = {}
            with open(seg_dict_file) as f:
                lines = f.readlines()
            for line in lines:
                s = line.strip().split()
                key = s[0]
                value = s[1:]
                self.seg_dict[key] = " ".join(value)

        if token_type is not None:
            if token_list is None:
                raise ValueError("token_list is required if token_type is not None")
            self.text_cleaner = TextCleaner(text_cleaner)

            self.tokenizer = build_tokenizer(
                token_type=token_type,
                bpemodel=bpemodel,
                delimiter=delimiter,
                space_symbol=space_symbol,
                non_linguistic_symbols=non_linguistic_symbols,
                g2p_type=g2p_type,
            )
            self.token_id_converter = TokenIDConverter(
                token_list=token_list,
                unk_symbol=unk_symbol,
            )
        else:
            self.text_cleaner = None
            self.tokenizer = None
            self.token_id_converter = None

        if train and rir_scp is not None:
            self.rirs = []
            with open(rir_scp, "r", encoding="utf-8") as f:
                for line in f:
                    sps = line.strip().split(None, 1)
                    if len(sps) == 1:
                        self.rirs.append(sps[0])
                    else:
                        self.rirs.append(sps[1])
        else:
            self.rirs = None

        if train and noise_scp is not None:
            self.noises = []
            with open(noise_scp, "r", encoding="utf-8") as f:
                for line in f:
                    sps = line.strip().split(None, 1)
                    if len(sps) == 1:
                        self.noises.append(sps[0])
                    else:
                        self.noises.append(sps[1])
            sps = noise_db_range.split("_")
            if len(sps) == 1:
                self.noise_db_low = self.noise_db_high = float(sps[0])
            elif len(sps) == 2:
                self.noise_db_low, self.noise_db_high = float(sps[0]), float(sps[1])
            else:
                raise ValueError(
                    f"Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]"
                )
        else:
            self.noises = None

    def _speech_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, Union[str, np.ndarray]]:
        assert check_argument_types()
        if self.speech_name in data:
            if self.train and (self.rirs is not None or self.noises is not None):
                speech = data[self.speech_name]
                nsamples = len(speech)

                # speech: (Nmic, Time)
                if speech.ndim == 1:
                    speech = speech[None, :]
                else:
                    speech = speech.T
                # Calc power on non-silence region
                power = (speech[detect_non_silence(speech)] ** 2).mean()

                # 1. Convolve RIR
                if self.rirs is not None and self.rir_apply_prob >= np.random.random():
                    rir_path = np.random.choice(self.rirs)
                    if rir_path is not None:
                        rir, _ = soundfile.read(
                            rir_path, dtype=np.float64, always_2d=True
                        )

                        # rir: (Nmic, Time)
                        rir = rir.T

                        # speech: (Nmic, Time)
                        # Note that this operation doesn't change the signal length
                        speech = scipy.signal.convolve(speech, rir, mode="full")[
                            :, : speech.shape[1]
                        ]
                        # Reverse mean power to the original power
                        power2 = (speech[detect_non_silence(speech)] ** 2).mean()
                        speech = np.sqrt(power / max(power2, 1e-10)) * speech

                # 2. Add Noise
                if (
                    self.noises is not None
                    and self.noise_apply_prob >= np.random.random()
                ):
                    noise_path = np.random.choice(self.noises)
                    if noise_path is not None:
                        noise_db = np.random.uniform(
                            self.noise_db_low, self.noise_db_high
                        )
                        with soundfile.SoundFile(noise_path) as f:
                            if f.frames == nsamples:
                                noise = f.read(dtype=np.float64, always_2d=True)
                            elif f.frames < nsamples:
                                offset = np.random.randint(0, nsamples - f.frames)
                                # noise: (Time, Nmic)
                                noise = f.read(dtype=np.float64, always_2d=True)
                                # Repeat noise
                                noise = np.pad(
                                    noise,
                                    [(offset, nsamples - f.frames - offset), (0, 0)],
                                    mode="wrap",
                                )
                            else:
                                offset = np.random.randint(0, f.frames - nsamples)
                                f.seek(offset)
                                # noise: (Time, Nmic)
                                noise = f.read(
                                    nsamples, dtype=np.float64, always_2d=True
                                )
                                if len(noise) != nsamples:
                                    raise RuntimeError(f"Something wrong: {noise_path}")
                        # noise: (Nmic, Time)
                        noise = noise.T

                        noise_power = (noise ** 2).mean()
                        scale = (
                            10 ** (-noise_db / 20)
                            * np.sqrt(power)
                            / np.sqrt(max(noise_power, 1e-10))
                        )
                        speech = speech + scale * noise

                speech = speech.T
                ma = np.max(np.abs(speech))
                if ma > 1.0:
                    speech /= ma
                data[self.speech_name] = speech

            if self.speech_volume_normalize is not None:
                speech = data[self.speech_name]
                ma = np.max(np.abs(speech))
                data[self.speech_name] = speech * self.speech_volume_normalize / ma
        assert check_return_type(data)
        return data

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        if self.text_name in data and self.tokenizer is not None:
            text = data[self.text_name]
            text = self.text_cleaner(text)
            if self.split_with_space:
                tokens = text.strip().split(" ")
                if self.seg_dict is not None:
                    tokens = seg_tokenize(tokens, self.seg_dict)
            else:
                tokens = self.tokenizer.text2tokens(text)
            text_ints = self.token_id_converter.tokens2ids(tokens)
            data[self.text_name] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data

    def __call__(
        self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        assert check_argument_types()

        data = self._speech_process(data)
        data = self._text_process(data)
        return data
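
# Hedged usage sketch (not part of the commit): the toy token_list below is
# illustrative, and the ids assume TokenIDConverter maps tokens to list indices.
# pre = CommonPreprocessor(train=False, token_type="word",
#                          token_list=["<blank>", "<unk>", "hello", "world"])
# pre("utt1", {"text": "hello world"})["text"]  # -> np.array([2, 3])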


## FIXME
class LMPreprocessor(CommonPreprocessor):
    def __init__(
        self,
        train: bool,
        token_type: str = None,
        token_list: Union[Path, str, Iterable[str]] = None,
        bpemodel: Union[Path, str, Iterable[str]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: str = "text",
        split_with_space: bool = False,
        seg_dict_file: str = None,
    ):
        super().__init__(train,
                         token_type,
                         token_list,
                         bpemodel,
                         text_cleaner,
                         g2p_type,
                         unk_symbol,
                         space_symbol,
                         non_linguistic_symbols,
                         delimiter,
                         rir_scp,
                         rir_apply_prob,
                         noise_scp,
                         noise_apply_prob,
                         noise_db_range,
                         speech_volume_normalize,
                         speech_name,
                         text_name,
                         split_with_space,
                         seg_dict_file,
                         )

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        if self.text_name in data and self.tokenizer is not None:
            text = data[self.text_name]
            text = self.text_cleaner(text)
            if self.split_with_space:
                tokens = text.strip().split(" ")
                if self.seg_dict is not None:
                    tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
            else:
                tokens = self.tokenizer.text2tokens(text)
            text_ints = self.token_id_converter.tokens2ids(tokens)
            data[self.text_name] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data


class CommonPreprocessor_multi(AbsPreprocessor):
    def __init__(
        self,
        train: bool,
        token_type: str = None,
        token_list: Union[Path, str, Iterable[str]] = None,
        bpemodel: Union[Path, str, Iterable[str]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        speech_name: str = "speech",
        text_name: List[str] = ["text"],
    ):
        super().__init__(train)
        self.train = train
        self.speech_name = speech_name
        self.text_name = text_name

        if token_type is not None:
            if token_list is None:
                raise ValueError("token_list is required if token_type is not None")
            self.text_cleaner = TextCleaner(text_cleaner)

            self.tokenizer = build_tokenizer(
                token_type=token_type,
                bpemodel=bpemodel,
                delimiter=delimiter,
                space_symbol=space_symbol,
                non_linguistic_symbols=non_linguistic_symbols,
                g2p_type=g2p_type,
            )
            self.token_id_converter = TokenIDConverter(
                token_list=token_list,
                unk_symbol=unk_symbol,
            )
        else:
            self.text_cleaner = None
            self.tokenizer = None
            self.token_id_converter = None

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        for text_n in self.text_name:
            if text_n in data and self.tokenizer is not None:
                text = data[text_n]
                text = self.text_cleaner(text)
                tokens = self.tokenizer.text2tokens(text)
                text_ints = self.token_id_converter.tokens2ids(tokens)
                data[text_n] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data

    def __call__(
        self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        assert check_argument_types()

        if self.speech_name in data:
            # Nothing now: candidates:
            # - STFT
            # - Fbank
            # - CMVN
            # - Data augmentation
            pass

        data = self._text_process(data)
        return data


class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
    def __init__(
        self,
        train: bool,
        token_type: List[str] = [None],
        token_list: List[Union[Path, str, Iterable[str]]] = [None],
        bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: List[str] = ["text"],
    ):
        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
        super().__init__(
            train=train,
            token_type=token_type[0],
            token_list=token_list[0],
            bpemodel=bpemodel[0],
            text_cleaner=text_cleaner,
            g2p_type=g2p_type,
            unk_symbol=unk_symbol,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            delimiter=delimiter,
            speech_name=speech_name,
            text_name=text_name[0],
            rir_scp=rir_scp,
            rir_apply_prob=rir_apply_prob,
            noise_scp=noise_scp,
            noise_apply_prob=noise_apply_prob,
            noise_db_range=noise_db_range,
            speech_volume_normalize=speech_volume_normalize,
        )

        assert (
            len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
        ), "token_type, token_list, bpemodel, or processing text_name mismatched"
        self.num_tokenizer = len(token_type)
        self.tokenizer = []
        self.token_id_converter = []

        for i in range(self.num_tokenizer):
            if token_type[i] is not None:
                if token_list[i] is None:
                    raise ValueError("token_list is required if token_type is not None")

                self.tokenizer.append(
                    build_tokenizer(
                        token_type=token_type[i],
                        bpemodel=bpemodel[i],
                        delimiter=delimiter,
                        space_symbol=space_symbol,
                        non_linguistic_symbols=non_linguistic_symbols,
                        g2p_type=g2p_type,
                    )
                )
                self.token_id_converter.append(
                    TokenIDConverter(
                        token_list=token_list[i],
                        unk_symbol=unk_symbol,
                    )
                )
            else:
                self.tokenizer.append(None)
                self.token_id_converter.append(None)

        self.text_cleaner = TextCleaner(text_cleaner)
        self.text_name = text_name  # override the text_name from CommonPreprocessor

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        for i in range(self.num_tokenizer):
            text_name = self.text_name[i]
            if text_name in data and self.tokenizer[i] is not None:
                text = data[text_name]
                text = self.text_cleaner(text)
                tokens = self.tokenizer[i].text2tokens(text)
                text_ints = self.token_id_converter[i].tokens2ids(tokens)
                data[text_name] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data


class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
    def __init__(
        self,
        train: bool,
        token_type: str = None,
        token_list: Union[Path, str, Iterable[str]] = None,
        bpemodel: Union[Path, str, Iterable[str]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: str = "text",
        split_text_name: str = "split_text",
        split_with_space: bool = False,
        seg_dict_file: str = None,
    ):
        super().__init__(
            train=train,
            # Force to use word.
            token_type="word",
            token_list=token_list,
            bpemodel=bpemodel,
            text_cleaner=text_cleaner,
            g2p_type=g2p_type,
            unk_symbol=unk_symbol,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            delimiter=delimiter,
            speech_name=speech_name,
            text_name=text_name,
            rir_scp=rir_scp,
            rir_apply_prob=rir_apply_prob,
            noise_scp=noise_scp,
            noise_apply_prob=noise_apply_prob,
            noise_db_range=noise_db_range,
            speech_volume_normalize=speech_volume_normalize,
            split_with_space=split_with_space,
            seg_dict_file=seg_dict_file,
        )
        # The data field name for split text.
        self.split_text_name = split_text_name

    @classmethod
    def split_words(cls, text: str):
        words = []
        segs = text.split()
        for seg in segs:
            # There is no space in seg.
            current_word = ""
            for c in seg:
                if len(c.encode()) == 1:
                    # This is an ASCII char.
                    current_word += c
                else:
                    # This is a Chinese char.
                    if len(current_word) > 0:
                        words.append(current_word)
                        current_word = ""
                    words.append(c)
            if len(current_word) > 0:
                words.append(current_word)
        return words

    def __call__(
        self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
    ) -> Dict[str, Union[list, np.ndarray]]:
        assert check_argument_types()
        # Split words.
        if isinstance(data[self.text_name], str):
            split_text = self.split_words(data[self.text_name])
        else:
            split_text = data[self.text_name]
        data[self.text_name] = " ".join(split_text)
        data = self._speech_process(data)
        data = self._text_process(data)
        data[self.split_text_name] = split_text
        return data

    def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
        result = data[self.split_text_name]
        del data[self.split_text_name]
        return result
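
# Hedged example of the code-mix splitter above (output inferred from the
# byte-length test: ASCII runs stay whole, multi-byte chars split):
# CodeMixTokenizerCommonPreprocessor.split_words("hello你好world")
# -> ["hello", "你", "好", "world"]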


class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
    def __init__(
        self,
        train: bool,
        token_type: List[str] = [None],
        token_list: List[Union[Path, str, Iterable[str]]] = [None],
        bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: List[str] = ["text"],
        vad_name: str = "vad_indexes",
    ):
        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
        super().__init__(
            train=train,
            token_type=token_type[0],
            token_list=token_list[0],
            bpemodel=bpemodel[0],
            text_cleaner=text_cleaner,
            g2p_type=g2p_type,
            unk_symbol=unk_symbol,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            delimiter=delimiter,
            speech_name=speech_name,
            text_name=text_name[0],
            rir_scp=rir_scp,
            rir_apply_prob=rir_apply_prob,
            noise_scp=noise_scp,
            noise_apply_prob=noise_apply_prob,
            noise_db_range=noise_db_range,
            speech_volume_normalize=speech_volume_normalize,
        )

        assert (
            len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
        ), "token_type, token_list, bpemodel, or processing text_name mismatched"
        self.num_tokenizer = len(token_type)
        self.tokenizer = []
        self.token_id_converter = []

        for i in range(self.num_tokenizer):
            if token_type[i] is not None:
                if token_list[i] is None:
                    raise ValueError("token_list is required if token_type is not None")

                self.tokenizer.append(
                    build_tokenizer(
                        token_type=token_type[i],
                        bpemodel=bpemodel[i],
                        delimiter=delimiter,
                        space_symbol=space_symbol,
                        non_linguistic_symbols=non_linguistic_symbols,
                        g2p_type=g2p_type,
                    )
                )
                self.token_id_converter.append(
                    TokenIDConverter(
                        token_list=token_list[i],
                        unk_symbol=unk_symbol,
                    )
                )
            else:
                self.tokenizer.append(None)
                self.token_id_converter.append(None)

        self.text_cleaner = TextCleaner(text_cleaner)
        self.text_name = text_name  # override the text_name from CommonPreprocessor
        self.vad_name = vad_name

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        for i in range(self.num_tokenizer):
            text_name = self.text_name[i]
            if text_name in data and self.tokenizer[i] is not None:
                text = data[text_name]
                text = self.text_cleaner(text)
                tokens = self.tokenizer[i].text2tokens(text)
                if "vad:" in tokens[-1]:
                    vad = tokens[-1][4:]
                    tokens = tokens[:-1]
                    if len(vad) == 0:
                        vad = -1
                    else:
                        vad = int(vad)
                    data[self.vad_name] = np.array([vad], dtype=np.int64)
                text_ints = self.token_id_converter[i].tokens2ids(tokens)
                data[text_name] = np.array(text_ints, dtype=np.int64)
        return data


def split_to_mini_sentence(words: list, word_limit: int = 20):
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    sentences = []
    length = len(words)
    sentence_len = length // word_limit
    for i in range(sentence_len):
        sentences.append(words[i * word_limit:(i + 1) * word_limit])
    if length % word_limit > 0:
        sentences.append(words[sentence_len * word_limit:])
    return sentences
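For example (values chosen to exercise the remainder branch):

print(split_to_mini_sentence(list(range(7)), word_limit=3))
# -> [[0, 1, 2], [3, 4, 5], [6]]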