mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
168
funasr_local/samplers/build_batch_sampler.py
Normal file
168
funasr_local/samplers/build_batch_sampler.py
Normal file
@@ -0,0 +1,168 @@
|
||||
from typing import List
|
||||
from typing import Dict
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from typeguard import check_argument_types
|
||||
from typeguard import check_return_type
|
||||
|
||||
from funasr_local.samplers.abs_sampler import AbsSampler
|
||||
from funasr_local.samplers.folded_batch_sampler import FoldedBatchSampler
|
||||
from funasr_local.samplers.length_batch_sampler import LengthBatchSampler
|
||||
from funasr_local.samplers.num_elements_batch_sampler import NumElementsBatchSampler
|
||||
from funasr_local.samplers.sorted_batch_sampler import SortedBatchSampler
|
||||
from funasr_local.samplers.unsorted_batch_sampler import UnsortedBatchSampler
|
||||
|
||||
|
||||
BATCH_TYPES = dict(
|
||||
unsorted="UnsortedBatchSampler has nothing in particular feature and "
|
||||
"just creates mini-batches which has constant batch_size. "
|
||||
"This sampler doesn't require any length "
|
||||
"information for each feature. "
|
||||
"'key_file' is just a text file which describes each sample name."
|
||||
"\n\n"
|
||||
" utterance_id_a\n"
|
||||
" utterance_id_b\n"
|
||||
" utterance_id_c\n"
|
||||
"\n"
|
||||
"The fist column is referred, so 'shape file' can be used, too.\n\n"
|
||||
" utterance_id_a 100,80\n"
|
||||
" utterance_id_b 400,80\n"
|
||||
" utterance_id_c 512,80\n",
|
||||
sorted="SortedBatchSampler sorts samples by the length of the first input "
|
||||
" in order to make each sample in a mini-batch has close length. "
|
||||
"This sampler requires a text file which describes the length for each sample "
|
||||
"\n\n"
|
||||
" utterance_id_a 1000\n"
|
||||
" utterance_id_b 1453\n"
|
||||
" utterance_id_c 1241\n"
|
||||
"\n"
|
||||
"The first element of feature dimensions is referred, "
|
||||
"so 'shape_file' can be also used.\n\n"
|
||||
" utterance_id_a 1000,80\n"
|
||||
" utterance_id_b 1453,80\n"
|
||||
" utterance_id_c 1241,80\n",
|
||||
folded="FoldedBatchSampler supports variable batch_size. "
|
||||
"The batch_size is decided by\n"
|
||||
" batch_size = base_batch_size // (L // fold_length)\n"
|
||||
"L is referred to the largest length of samples in the mini-batch. "
|
||||
"This samples requires length information as same as SortedBatchSampler\n",
|
||||
length="LengthBatchSampler supports variable batch_size. "
|
||||
"This sampler makes mini-batches which have same number of 'bins' as possible "
|
||||
"counting by the total lengths of each feature in the mini-batch. "
|
||||
"This sampler requires a text file which describes the length for each sample. "
|
||||
"\n\n"
|
||||
" utterance_id_a 1000\n"
|
||||
" utterance_id_b 1453\n"
|
||||
" utterance_id_c 1241\n"
|
||||
"\n"
|
||||
"The first element of feature dimensions is referred, "
|
||||
"so 'shape_file' can be also used.\n\n"
|
||||
" utterance_id_a 1000,80\n"
|
||||
" utterance_id_b 1453,80\n"
|
||||
" utterance_id_c 1241,80\n",
|
||||
numel="NumElementsBatchSampler supports variable batch_size. "
|
||||
"Just like LengthBatchSampler, this sampler makes mini-batches"
|
||||
" which have same number of 'bins' as possible "
|
||||
"counting by the total number of elements of each feature "
|
||||
"instead of the length. "
|
||||
"Thus this sampler requires the full information of the dimension of the features. "
|
||||
"\n\n"
|
||||
" utterance_id_a 1000,80\n"
|
||||
" utterance_id_b 1453,80\n"
|
||||
" utterance_id_c 1241,80\n",
|
||||
)
|
||||
|
||||
|
||||
def build_batch_sampler(
    type: str,
    batch_size: int,
    batch_bins: int,
    shape_files: Union[Tuple[str, ...], List[str], Dict],
    sort_in_batch: str = "descending",
    sort_batch: str = "ascending",
    drop_last: bool = False,
    min_batch_size: int = 1,
    fold_lengths: Sequence[int] = (),
    padding: bool = True,
    utt2category_file: str = None,
) -> AbsSampler:
    """Create the batch sampler selected by ``type``.

    Args:
        type: Name of the sampler to build: "unsorted", "sorted", "folded",
            "numel" or "length".
        batch_size: Mini-batch size, forwarded to the "unsorted", "sorted"
            and "folded" samplers.
        batch_bins: Bin budget, forwarded to the "numel" and "length" samplers.
        shape_files: Text files describing the length and dimension of each
            feature, e.g. ``uttA 1330,80``. "unsorted" and "sorted" read only
            the first entry; the other samplers receive the whole collection.
        sort_in_batch: Forwarded to every sampler except "unsorted".
        sort_batch: Forwarded to every sampler except "unsorted".
        drop_last: Forwarded to every sampler.
        min_batch_size: Forwarded to the "folded", "numel" and "length"
            samplers.
        fold_lengths: Forwarded to the "folded" sampler; must contain as many
            entries as ``shape_files``.
        padding: Whether sequences are input as a padded tensor or not.
            Forwarded to the "numel" and "length" samplers.
        utt2category_file: Forwarded to the "folded" sampler only.

    Returns:
        The sampler instance constructed for ``type``.

    Raises:
        ValueError: If ``shape_files`` is empty, if ``type`` is "folded" and
            the number of ``fold_lengths`` differs from the number of
            ``shape_files``, or if ``type`` is not a supported name.
    """
    # Run the typeguard argument check before anything else in the body.
    assert check_argument_types()
    if len(shape_files) == 0:
        raise ValueError("No shape file are given")

    # Ordering options shared by every sampler except the unsorted one.
    ordering = dict(
        sort_in_batch=sort_in_batch,
        sort_batch=sort_batch,
        drop_last=drop_last,
    )

    if type == "unsorted":
        # Only the first file is read, as a list of sample keys.
        retval = UnsortedBatchSampler(
            batch_size=batch_size,
            key_file=shape_files[0],
            drop_last=drop_last,
        )

    elif type == "sorted":
        # Only the first shape file is used for sorting.
        retval = SortedBatchSampler(
            batch_size=batch_size,
            shape_file=shape_files[0],
            **ordering,
        )

    elif type == "folded":
        n_folds, n_shapes = len(fold_lengths), len(shape_files)
        if n_folds != n_shapes:
            raise ValueError(
                f"The number of fold_lengths must be equal to "
                f"the number of shape_files: "
                f"{n_folds} != {n_shapes}"
            )
        retval = FoldedBatchSampler(
            batch_size=batch_size,
            shape_files=shape_files,
            fold_lengths=fold_lengths,
            min_batch_size=min_batch_size,
            utt2category_file=utt2category_file,
            **ordering,
        )

    elif type in ("numel", "length"):
        # Both bin-based samplers take exactly the same keyword arguments.
        if type == "numel":
            sampler_cls = NumElementsBatchSampler
        else:
            sampler_cls = LengthBatchSampler
        retval = sampler_cls(
            batch_bins=batch_bins,
            shape_files=shape_files,
            padding=padding,
            min_batch_size=min_batch_size,
            **ordering,
        )

    else:
        raise ValueError(f"Not supported: {type}")

    # Check the result against the annotated return type before handing it out.
    assert check_return_type(retval)
    return retval
|
||||
Reference in New Issue
Block a user