add files

Author: 烨玮
Date: 2025-02-20 12:17:03 +08:00
Parent: a21dd4555c
Commit: edd008441b
667 changed files with 473123 additions and 0 deletions

funasr_local/__init__.py Normal file

@@ -0,0 +1,8 @@
"""Initialize funasr_local package."""
import os
dirname = os.path.dirname(__file__)
version_file = os.path.join(dirname, "version.txt")
with open(version_file, "r") as f:
__version__ = f.read().strip()
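As a quick sanity check, the version string exposed above can be read directly from the installed package; a minimal sketch, assuming the package is importable:

import funasr_local
# Prints the contents of funasr_local/version.txt, e.g. "0.1.0" (illustrative value only).
print(funasr_local.__version__)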


@@ -0,0 +1,108 @@
#!/usr/bin/env python3
import argparse
import logging
import sys
from pathlib import Path
from typing import Iterable
from typing import Union
import numpy as np
from funasr_local.utils.cli_utils import get_commandline_args
def aggregate_stats_dirs(
input_dir: Iterable[Union[str, Path]],
output_dir: Union[str, Path],
log_level: str,
skip_sum_stats: bool,
):
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) (levelname)s: %(message)s",
)
input_dirs = [Path(p) for p in input_dir]
output_dir = Path(output_dir)
for mode in ["train", "valid"]:
with (input_dirs[0] / mode / "batch_keys").open("r", encoding="utf-8") as f:
batch_keys = [line.strip() for line in f if line.strip() != ""]
with (input_dirs[0] / mode / "stats_keys").open("r", encoding="utf-8") as f:
stats_keys = [line.strip() for line in f if line.strip() != ""]
(output_dir / mode).mkdir(parents=True, exist_ok=True)
for key in batch_keys:
with (output_dir / mode / f"{key}_shape").open(
"w", encoding="utf-8"
) as fout:
for idir in input_dirs:
with (idir / mode / f"{key}_shape").open(
"r", encoding="utf-8"
) as fin:
# Read to the last in order to sort keys
# because the order can be changed if num_workers>=1
lines = fin.readlines()
lines = sorted(lines, key=lambda x: x.split()[0])
for line in lines:
fout.write(line)
for key in stats_keys:
if not skip_sum_stats:
sum_stats = None
for idir in input_dirs:
stats = np.load(idir / mode / f"{key}_stats.npz")
if sum_stats is None:
sum_stats = dict(**stats)
else:
for k in stats:
sum_stats[k] += stats[k]
np.savez(output_dir / mode / f"{key}_stats.npz", **sum_stats)
# if --write_collected_feats=true
p = Path(mode) / "collect_feats" / f"{key}.scp"
scp = input_dirs[0] / p
if scp.exists():
(output_dir / p).parent.mkdir(parents=True, exist_ok=True)
with (output_dir / p).open("w", encoding="utf-8") as fout:
for idir in input_dirs:
with (idir / p).open("r", encoding="utf-8") as fin:
for line in fin:
fout.write(line)
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Aggregate statistics directories into one directory",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument(
"--skip_sum_stats",
default=False,
action="store_true",
help="Skip computing the sum of statistics.",
)
parser.add_argument("--input_dir", action="append", help="Input directories")
parser.add_argument("--output_dir", required=True, help="Output directory")
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
aggregate_stats_dirs(**kwargs)
if __name__ == "__main__":
main()
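The entry point above can be driven from the command line or programmatically. A minimal sketch, assuming each input directory already contains the train/valid sub-directories written by a collect-stats run; all paths are placeholders:

# Hypothetical example: merge two collect-stats runs into one directory.
# "exp/stats.1", "exp/stats.2" and "exp/stats" are placeholder paths.
main(cmd=[
    "--input_dir", "exp/stats.1",
    "--input_dir", "exp/stats.2",
    "--output_dir", "exp/stats",
])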


@@ -0,0 +1,640 @@
#!/usr/bin/env python3
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import logging
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.fileio.datadir_writer import DatadirWriter
from funasr_local.modules.beam_search.batch_beam_search import BatchBeamSearch
from funasr_local.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
from funasr_local.modules.beam_search.beam_search import BeamSearch
from funasr_local.modules.beam_search.beam_search import Hypothesis
from funasr_local.modules.scorers.ctc import CTCPrefixScorer
from funasr_local.modules.scorers.length_bonus import LengthBonus
from funasr_local.modules.scorers.scorer_interface import BatchScorerInterface
from funasr_local.modules.subsampling import TooShortUttError
from funasr_local.tasks.asr import ASRTask
from funasr_local.tasks.lm import LMTask
from funasr_local.text.build_tokenizer import build_tokenizer
from funasr_local.text.token_id_converter import TokenIDConverter
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from funasr_local.utils import asr_utils, wav_utils, postprocess_utils
from funasr_local.models.frontend.wav_frontend import WavFrontend
header_colors = '\033[95m'
end_colors = '\033[0m'
class Speech2Text:
"""Speech2Text class
Examples:
>>> import soundfile
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
batch_size: int = 1,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
streaming: bool = False,
frontend_conf: dict = None,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
decoder = asr_model.decoder
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
token_list = asr_model.token_list
scorers.update(
decoder=decoder,
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device
)
scorers["lm"] = lm.lm
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
lm=lm_weight,
ngram=ngram_weight,
length_bonus=penalty,
)
beam_search = BeamSearch(
beam_size=beam_size,
weights=weights,
scorers=scorers,
sos=asr_model.sos,
eos=asr_model.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
self.beam_search = beam_search
self.beam_search_transducer = beam_search_transducer
self.maxlenratio = maxlenratio
self.minlenratio = minlenratio
self.device = device
self.dtype = dtype
self.nbest = nbest
self.frontend = frontend
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
) -> List[
Tuple[
Optional[str],
List[str],
List[int],
Union[Hypothesis],
]
]:
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
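# If a WavFrontend was built from the training config, compute features here and
# bypass the model's internal frontend (set to None below) so encode() consumes features directly.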
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
self.asr_model.frontend = None
else:
feats = speech
feats_len = speech_lengths
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
# b. Forward Encoder
enc, _ = self.asr_model.encode(**batch)
if isinstance(enc, tuple):
enc = enc[0]
assert len(enc) == 1, len(enc)
# c. Pass the encoder result to the beam search
nbest_hyps = self.beam_search(
x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
)
nbest_hyps = nbest_hyps[: self.nbest]
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp))
assert check_return_type(results)
return results
def inference(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
**kwargs,
):
inference_pipeline = inference_modelscope(
maxlenratio=maxlenratio,
minlenratio=minlenratio,
batch_size=batch_size,
beam_size=beam_size,
ngpu=ngpu,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
penalty=penalty,
log_level=log_level,
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
raw_inputs=raw_inputs,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
key_file=key_file,
word_lm_train_config=word_lm_train_config,
bpemodel=bpemodel,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
output_dir=output_dir,
dtype=dtype,
seed=seed,
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
# data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
streaming=streaming,
)
logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
speech2text = Speech2Text(**speech2text_kwargs)
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
**kwargs,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
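# In-memory audio: wrap the raw array as a single "speech"/"waveform" entry for the streaming iterator.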
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
finish_count = 0
file_count = 1
# 7. Start the decoding loop
# FIXME(kamo): The output format should be discussed
asr_result_list = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
writer = DatadirWriter(output_path)
else:
writer = None
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["sil"], [2], hyp]] * nbest
# Only supporting batch_size==1
key = keys[0]
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text
return asr_result_list
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global cmvn file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--word_lm_train_config",
type=str,
help="Word LM training configuration",
)
group.add_argument(
"--word_lm_file",
type=str,
help="Word LM parameter file",
)
group.add_argument(
"--ngram_file",
type=str,
help="N-gram parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
group.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), it uses a end-detect "
"function "
"to automatically find maximum hypothesis lengths."
"If maxlenratio<0.0, its absolute value is interpreted"
"as a constant max output length",
)
group.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.5,
help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, refers from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, refers from the training args",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()
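For reference, the modelscope-style pipeline defined above can also be used as a library. A minimal sketch, assuming a trained config/checkpoint pair is available; every path below is a placeholder:

# Hypothetical example: build a decode pipeline and run it on a Kaldi-style wav.scp.
pipeline = inference_modelscope(
    maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=20,
    ngpu=0, ctc_weight=0.5, lm_weight=0.0, penalty=0.0, log_level="INFO",
    asr_train_config="exp/asr/config.yaml",  # placeholder path
    asr_model_file="exp/asr/model.pb",       # placeholder path
)
results = pipeline([("data/wav.scp", "speech", "sound")])  # list of {"key": ..., "value": ...} dicts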


@@ -0,0 +1,345 @@
#!/usr/bin/env python3
import argparse
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--njob",
type=int,
default=1,
help="The number of jobs for each gpu",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=True,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",
type=str,
help="VAD infer configuration",
)
group.add_argument(
"--vad_model_file",
type=str,
help="VAD model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global CMVN file",
)
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--word_lm_train_config",
type=str,
help="Word LM training configuration",
)
group.add_argument(
"--word_lm_file",
type=str,
help="Word LM parameter file",
)
group.add_argument(
"--ngram_file",
type=str,
help="N-gram parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group.add_argument(
"--beam_search_config",
default={},
help="The keyword arguments for transducer beam search.",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
group.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), it uses a end-detect "
"function "
"to automatically find maximum hypothesis lengths."
"If maxlenratio<0.0, its absolute value is interpreted"
"as a constant max output length",
)
group.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.0,
help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
group.add_argument("--simu_streaming", type=str2bool, default=False)
group.add_argument("--chunk_size", type=int, default=16)
group.add_argument("--left_context", type=int, default=16)
group.add_argument("--right_context", type=int, default=0)
group.add_argument(
"--display_partial_hypotheses",
type=bool,
default=False,
help="Whether to display partial hypotheses during chunk-by-chunk inference.",
)
group = parser.add_argument_group("Dynamic quantization related")
group.add_argument(
"--quantize_asr_model",
type=bool,
default=False,
help="Apply dynamic quantization to ASR model.",
)
group.add_argument(
"--quantize_modules",
nargs="*",
default=None,
help="""Module names to apply dynamic quantization on.
The module names are provided as a list, where each name is separated
by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
Each specified name should be an attribute of 'torch.nn', e.g.:
torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
)
group.add_argument(
"--quantize_dtype",
type=str,
default="qint8",
choices=["float16", "qint8"],
help="Dtype for dynamic quantization.",
)
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, refers from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, refers from the training args",
)
group.add_argument("--token_num_relax", type=int, default=1, help="")
group.add_argument("--decoding_ind", type=int, default=0, help="")
group.add_argument("--decoding_mode", type=str, default="model1", help="")
group.add_argument(
"--ctc_weight2",
type=float,
default=0.0,
help="CTC weight in joint decoding",
)
return parser
def inference_launch(**kwargs):
if 'mode' in kwargs:
mode = kwargs['mode']
else:
logging.info("Unknown decoding mode.")
return None
if mode == "asr":
from funasr_local.bin.asr_inference import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "uniasr":
from funasr_local.bin.asr_inference_uniasr import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "uniasr_vad":
from funasr_local.bin.asr_inference_uniasr_vad import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "paraformer":
from funasr_local.bin.asr_inference_paraformer import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "paraformer_streaming":
from funasr_local.bin.asr_inference_paraformer_streaming import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "paraformer_vad":
from funasr_local.bin.asr_inference_paraformer_vad import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "paraformer_punc":
logging.info("Unknown decoding mode: {}".format(mode))
return None
elif mode == "paraformer_vad_punc":
from funasr_local.bin.asr_inference_paraformer_vad_punc import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "vad":
from funasr_local.bin.vad_inference import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "mfcca":
from funasr_local.bin.asr_inference_mfcca import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "rnnt":
from funasr_local.bin.asr_inference_rnnt import inference_modelscope
return inference_modelscope(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
def inference_launch_funasr_local(**kwargs):
if 'mode' in kwargs:
mode = kwargs['mode']
else:
logging.info("Unknown decoding mode.")
return None
if mode == "asr":
from funasr_local.bin.asr_inference import inference
return inference(**kwargs)
elif mode == "uniasr":
from funasr_local.bin.asr_inference_uniasr import inference
return inference(**kwargs)
elif mode == "paraformer":
from funasr_local.bin.asr_inference_paraformer import inference
return inference(**kwargs)
elif mode == "paraformer_vad_punc":
from funasr_local.bin.asr_inference_paraformer_vad_punc import inference
return inference(**kwargs)
elif mode == "vad":
from funasr_local.bin.vad_inference import inference
return inference(**kwargs)
elif mode == "mfcca":
from funasr_local.bin.asr_inference_mfcca import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "rnnt":
from funasr_local.bin.asr_inference_rnnt import inference
return inference(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
parser.add_argument(
"--mode",
type=str,
default="asr",
help="The decoding mode",
)
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
# set logging messages
logging.basicConfig(
level=args.log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("Decoding args: {}".format(kwargs))
# gpu setting
if args.ngpu > 0:
jobid = int(args.output_dir.split(".")[-1])
gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
inference_launch_funasr_local(**kwargs)
if __name__ == "__main__":
main()
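The launcher can likewise be used as a library; mode selects which backend's inference_modelscope is built. A hedged sketch with placeholder paths, reusing the keyword arguments shown for the plain ASR pipeline above:

# Hypothetical example: mode="asr" routes to funasr_local.bin.asr_inference.inference_modelscope.
pipeline = inference_launch(
    mode="asr",
    maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=20,
    ngpu=0, ctc_weight=0.5, lm_weight=0.0, penalty=0.0, log_level="INFO",
    asr_train_config="exp/asr/config.yaml",  # placeholder path
    asr_model_file="exp/asr/model.pb",       # placeholder path
)
results = pipeline([("data/wav.scp", "speech", "sound")])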


@@ -0,0 +1,767 @@
#!/usr/bin/env python3
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import logging
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.fileio.datadir_writer import DatadirWriter
from funasr_local.modules.beam_search.batch_beam_search import BatchBeamSearch
from funasr_local.modules.beam_search.beam_search import BeamSearch
from funasr_local.modules.beam_search.beam_search import Hypothesis
from funasr_local.modules.scorers.ctc import CTCPrefixScorer
from funasr_local.modules.scorers.length_bonus import LengthBonus
from funasr_local.modules.scorers.scorer_interface import BatchScorerInterface
from funasr_local.modules.subsampling import TooShortUttError
from funasr_local.tasks.asr import ASRTaskMFCCA as ASRTask
from funasr_local.tasks.lm import LMTask
from funasr_local.text.build_tokenizer import build_tokenizer
from funasr_local.text.token_id_converter import TokenIDConverter
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from funasr_local.utils import asr_utils, wav_utils, postprocess_utils
import pdb
global_asr_language: str = 'zh-cn'
global_sample_rate: Union[int, Dict[Any, int]] = {
'audio_fs': 16000,
'model_fs': 16000
}
class Speech2Text:
"""Speech2Text class
Examples:
>>> import soundfile
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
batch_size: int = 1,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
streaming: bool = False,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
decoder = asr_model.decoder
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
token_list = asr_model.token_list
scorers.update(
decoder=decoder,
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device
)
lm.to(device)
scorers["lm"] = lm.lm
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
lm=lm_weight,
ngram=ngram_weight,
length_bonus=penalty,
)
beam_search = BeamSearch(
beam_size=beam_size,
weights=weights,
scorers=scorers,
sos=asr_model.sos,
eos=asr_model.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
#beam_search.__class__ = BatchBeamSearch
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
self.beam_search = beam_search
self.beam_search_transducer = beam_search_transducer
self.maxlenratio = maxlenratio
self.minlenratio = minlenratio
self.device = device
self.dtype = dtype
self.nbest = nbest
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
) -> List[
Tuple[
Optional[str],
List[str],
List[int],
Union[Hypothesis],
]
]:
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if speech.dim() == 3:
speech = torch.squeeze(speech, 2)
#speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
speech = speech.to(getattr(torch, self.dtype))
# lengths: (1,)
lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
batch = {"speech": speech, "speech_lengths": lengths}
# a. To device
batch = to_device(batch, device=self.device)
# b. Forward Encoder
enc, _ = self.asr_model.encode(**batch)
assert len(enc) == 1, len(enc)
# c. Pass the encoder result to the beam search
nbest_hyps = self.beam_search(
x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
)
nbest_hyps = nbest_hyps[: self.nbest]
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp))
assert check_return_type(results)
return results
# def inference(
# maxlenratio: float,
# minlenratio: float,
# batch_size: int,
# beam_size: int,
# ngpu: int,
# ctc_weight: float,
# lm_weight: float,
# penalty: float,
# log_level: Union[int, str],
# data_path_and_name_and_type,
# asr_train_config: Optional[str],
# asr_model_file: Optional[str],
# cmvn_file: Optional[str] = None,
# lm_train_config: Optional[str] = None,
# lm_file: Optional[str] = None,
# token_type: Optional[str] = None,
# key_file: Optional[str] = None,
# word_lm_train_config: Optional[str] = None,
# bpemodel: Optional[str] = None,
# allow_variable_data_keys: bool = False,
# streaming: bool = False,
# output_dir: Optional[str] = None,
# dtype: str = "float32",
# seed: int = 0,
# ngram_weight: float = 0.9,
# nbest: int = 1,
# num_workers: int = 1,
# **kwargs,
# ):
# assert check_argument_types()
# if batch_size > 1:
# raise NotImplementedError("batch decoding is not implemented")
# if word_lm_train_config is not None:
# raise NotImplementedError("Word LM is not implemented")
# if ngpu > 1:
# raise NotImplementedError("only single GPU decoding is supported")
#
# logging.basicConfig(
# level=log_level,
# format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
# )
#
# if ngpu >= 1 and torch.cuda.is_available():
# device = "cuda"
# else:
# device = "cpu"
#
# # 1. Set random-seed
# set_all_random_seed(seed)
#
# # 2. Build speech2text
# speech2text_kwargs = dict(
# asr_train_config=asr_train_config,
# asr_model_file=asr_model_file,
# cmvn_file=cmvn_file,
# lm_train_config=lm_train_config,
# lm_file=lm_file,
# token_type=token_type,
# bpemodel=bpemodel,
# device=device,
# maxlenratio=maxlenratio,
# minlenratio=minlenratio,
# dtype=dtype,
# beam_size=beam_size,
# ctc_weight=ctc_weight,
# lm_weight=lm_weight,
# ngram_weight=ngram_weight,
# penalty=penalty,
# nbest=nbest,
# streaming=streaming,
# )
# logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
# speech2text = Speech2Text(**speech2text_kwargs)
#
# # 3. Build data-iterator
# loader = ASRTask.build_streaming_iterator(
# data_path_and_name_and_type,
# dtype=dtype,
# batch_size=batch_size,
# key_file=key_file,
# num_workers=num_workers,
# preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
# collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
# allow_variable_data_keys=allow_variable_data_keys,
# inference=True,
# )
#
# finish_count = 0
# file_count = 1
# # 7 .Start for-loop
# # FIXME(kamo): The output format should be discussed about
# asr_result_list = []
# if output_dir is not None:
# writer = DatadirWriter(output_dir)
# else:
# writer = None
#
# for keys, batch in loader:
# assert isinstance(batch, dict), type(batch)
# assert all(isinstance(s, str) for s in keys), keys
# _bs = len(next(iter(batch.values())))
# assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# #batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
#
# # N-best list of (text, token, token_int, hyp_object)
# try:
# results = speech2text(**batch)
# except TooShortUttError as e:
# logging.warning(f"Utterance {keys} {e}")
# hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
# results = [[" ", ["<space>"], [2], hyp]] * nbest
#
# # Only supporting batch_size==1
# key = keys[0]
# for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# # Create a directory: outdir/{n}best_recog
# if writer is not None:
# ibest_writer = writer[f"{n}best_recog"]
#
# # Write the result to each file
# ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
# ibest_writer["score"][key] = str(hyp.score)
#
# if text is not None:
# text_postprocessed = postprocess_utils.sentence_postprocess(token)
# item = {'key': key, 'value': text_postprocessed}
# asr_result_list.append(item)
# finish_count += 1
# asr_utils.print_progress(finish_count / file_count)
# if writer is not None:
# ibest_writer["text"][key] = text
# return asr_result_list
def inference(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
**kwargs,
):
inference_pipeline = inference_modelscope(
maxlenratio=maxlenratio,
minlenratio=minlenratio,
batch_size=batch_size,
beam_size=beam_size,
ngpu=ngpu,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
penalty=penalty,
log_level=log_level,
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
raw_inputs=raw_inputs,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
key_file=key_file,
word_lm_train_config=word_lm_train_config,
bpemodel=bpemodel,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
output_dir=output_dir,
dtype=dtype,
seed=seed,
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
# data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
streaming=streaming,
)
logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
speech2text = Speech2Text(**speech2text_kwargs)
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
**kwargs,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
fs=fs,
mc=True,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
finish_count = 0
file_count = 1
# 7. Start the decoding loop
# FIXME(kamo): The output format should be discussed
asr_result_list = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
writer = DatadirWriter(output_path)
else:
writer = None
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["<space>"], [2], hyp]] * nbest
# Only supporting batch_size==1
key = keys[0]
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
text_postprocessed = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text
return asr_result_list
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global cmvn file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--word_lm_train_config",
type=str,
help="Word LM training configuration",
)
group.add_argument(
"--word_lm_file",
type=str,
help="Word LM parameter file",
)
group.add_argument(
"--ngram_file",
type=str,
help="N-gram parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
group.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), it uses a end-detect "
"function "
"to automatically find maximum hypothesis lengths."
"If maxlenratio<0.0, its absolute value is interpreted"
"as a constant max output length",
)
group.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.5,
help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, refers from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, refers from the training args",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()

(File diff suppressed because it is too large.)


@@ -0,0 +1,761 @@
#!/usr/bin/env python3
import argparse
import logging
import sys
import time
import copy
import os
import codecs
import tempfile
import requests
import yaml
from pathlib import Path
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
import numpy as np
import torch
import torchaudio
from typeguard import check_argument_types
from funasr_local.fileio.datadir_writer import DatadirWriter
from funasr_local.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr_local.modules.beam_search.beam_search import Hypothesis
from funasr_local.modules.scorers.ctc import CTCPrefixScorer
from funasr_local.modules.scorers.length_bonus import LengthBonus
from funasr_local.modules.subsampling import TooShortUttError
from funasr_local.tasks.asr import ASRTaskParaformer as ASRTask
from funasr_local.tasks.lm import LMTask
from funasr_local.text.build_tokenizer import build_tokenizer
from funasr_local.text.token_id_converter import TokenIDConverter
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from funasr_local.utils import asr_utils, wav_utils, postprocess_utils
from funasr_local.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr_local.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
np.set_printoptions(threshold=np.inf)
class Speech2Text:
"""Speech2Text class
Examples:
>>> import soundfile
>>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
frontend_conf: dict = None,
hotword_list_or_file: str = None,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontendOnline(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
if asr_model.ctc is not None:
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
scorers.update(
ctc=ctc
)
token_list = asr_model.token_list
scorers.update(
length_bonus=LengthBonus(len(token_list)),
)
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device
)
scorers["lm"] = lm.lm
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
lm=lm_weight,
ngram=ngram_weight,
length_bonus=penalty,
)
beam_search = BeamSearch(
beam_size=beam_size,
weights=weights,
scorers=scorers,
sos=asr_model.sos,
eos=asr_model.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
for scorer in scorers.values():
if isinstance(scorer, torch.nn.Module):
scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
logging.info(f"Decoding device={device}, dtype={dtype}")
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
# 6. [Optional] Build hotword list from str, local file or url
is_use_lm = lm_weight != 0.0 and lm_file is not None
if (ctc_weight == 0.0 or asr_model.ctc is None) and not is_use_lm:
beam_search = None
self.beam_search = beam_search
logging.info(f"Beam_search: {self.beam_search}")
self.beam_search_transducer = beam_search_transducer
self.maxlenratio = maxlenratio
self.minlenratio = minlenratio
self.device = device
self.dtype = dtype
self.nbest = nbest
self.frontend = frontend
self.encoder_downsampling_factor = 1
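# conv2d-style (and data2vec) encoders subsample the input by 4; the factor is used
# below to convert encoder output lengths back to input-frame counts.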
if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
self.encoder_downsampling_factor = 4
@torch.no_grad()
def __call__(
self, cache: dict, speech: Union[torch.Tensor], speech_lengths: Union[torch.Tensor] = None
):
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
results = []
cache_en = cache["encoder"]
if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
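# Final call with almost no new audio (< 960 samples, i.e. ~60 ms at 16 kHz):
# flush the features already cached by the online frontend as a tail chunk.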
if cache_en["start_idx"] == 0:
return []
cache_en["tail_chunk"] = True
feats = cache_en["feats"]
feats_len = torch.tensor([feats.shape[1]])
self.asr_model.frontend = None
results = self.infer(feats, feats_len, cache)
return results
else:
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"])
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
self.asr_model.frontend = None
else:
feats = speech
feats_len = speech_lengths
if feats.shape[1] != 0:
if cache_en["is_final"]:
if feats.shape[1] + cache_en["chunk_size"][2] < cache_en["chunk_size"][1]:
cache_en["last_chunk"] = True
else:
# first chunk
feats_chunk1 = feats[:, :cache_en["chunk_size"][1], :]
feats_len = torch.tensor([feats_chunk1.shape[1]])
results_chunk1 = self.infer(feats_chunk1, feats_len, cache)
# last chunk
cache_en["last_chunk"] = True
feats_chunk2 = feats[:, -(feats.shape[1] + cache_en["chunk_size"][2] - cache_en["chunk_size"][1]):, :]
feats_len = torch.tensor([feats_chunk2.shape[1]])
results_chunk2 = self.infer(feats_chunk2, feats_len, cache)
return ["".join(results_chunk1 + results_chunk2)]
results = self.infer(feats, feats_len, cache)
return results
@torch.no_grad()
def infer(self, feats: torch.Tensor, feats_len: torch.Tensor, cache: dict = None):
batch = {"speech": feats, "speech_lengths": feats_len}
batch = to_device(batch, device=self.device)
# b. Forward Encoder
enc, enc_len = self.asr_model.encode_chunk(feats, feats_len, cache=cache)
if isinstance(enc, tuple):
enc = enc[0]
# assert len(enc) == 1, len(enc)
enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache)
pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
if torch.max(pre_token_length) < 1:
return []
decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache)
decoder_out = decoder_outs
results = []
b, n, d = decoder_out.size()
for i in range(b):
x = enc[i, :enc_len[i], :]
am_scores = decoder_out[i, :pre_token_length[i], :]
if self.beam_search is not None:
nbest_hyps = self.beam_search(
x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
)
nbest_hyps = nbest_hyps[: self.nbest]
else:
yseq = am_scores.argmax(dim=-1)
score = am_scores.max(dim=-1)[0]
score = torch.sum(score, dim=-1)
# wrap the greedy output with sos/eos so it matches the beam-search hypothesis format
yseq = torch.tensor(
[self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
)
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove the blank symbol id (assumed to be 0) and symbol id 2, which are filtered as non-text tokens
token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append(text)
# assert check_return_type(results)
return results
def inference(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
**kwargs,
):
inference_pipeline = inference_modelscope(
maxlenratio=maxlenratio,
minlenratio=minlenratio,
batch_size=batch_size,
beam_size=beam_size,
ngpu=ngpu,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
penalty=penalty,
log_level=log_level,
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
raw_inputs=raw_inputs,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
key_file=key_file,
word_lm_train_config=word_lm_train_config,
bpemodel=bpemodel,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
output_dir=output_dir,
dtype=dtype,
seed=seed,
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
# data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
output_dir: Optional[str] = None,
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
export_mode = False
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
batch_size = 1
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
)
speech2text = Speech2Text(**speech2text_kwargs)
def _load_bytes(input):
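# Interpret the raw bytes as little-endian 16-bit PCM and normalize them to a
# float32 waveform in [-1.0, 1.0) by removing the integer offset and dividing
# by the absolute maximum of the sample type.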
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
if middle_data.dtype.kind not in 'iu':
raise TypeError("'middle_data' must be an array of integers")
dtype = np.dtype('float32')
if dtype.kind != 'f':
raise TypeError("'dtype' must be a floating point type")
i = np.iinfo(middle_data.dtype)
abs_max = 2 ** (i.bits - 1)
offset = i.min + abs_max
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
def _read_yaml(yaml_path: Union[str, Path]) -> Dict:
if not Path(yaml_path).exists():
raise FileNotFoundError(f'The file {yaml_path} does not exist.')
with open(str(yaml_path), 'rb') as f:
data = yaml.load(f, Loader=yaml.Loader)
return data
def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
if len(cache) > 0:
return cache
config = _read_yaml(asr_train_config)
enc_output_size = config["encoder_conf"]["output_size"]
feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
"cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
"feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
cache["encoder"] = cache_en
cache_de = {"decode_fsmn": None}
cache["decoder"] = cache_de
return cache
def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
if len(cache) > 0:
config = _read_yaml(asr_train_config)
enc_output_size = config["encoder_conf"]["output_size"]
feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
"cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
"feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
cache["encoder"] = cache_en
cache_de = {"decode_fsmn": None}
cache["decoder"] = cache_de
return cache
def _forward(
data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
**kwargs,
):
# 3. Build data-iterator
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
raw_inputs = _load_bytes(data_path_and_name_and_type[0])
raw_inputs = torch.tensor(raw_inputs)
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, np.ndarray):
raw_inputs = torch.tensor(raw_inputs)
is_final = False
cache = {}
chunk_size = [5, 10, 5]
if param_dict is not None and "cache" in param_dict:
cache = param_dict["cache"]
if param_dict is not None and "is_final" in param_dict:
is_final = param_dict["is_final"]
if param_dict is not None and "chunk_size" in param_dict:
chunk_size = param_dict["chunk_size"]
# 7 .Start for-loop
# FIXME(kamo): The output format should be discussed.
raw_inputs = torch.unsqueeze(raw_inputs, dim=0)
asr_result_list = []
cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
item = {}
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
sample_offset = 0
speech_length = raw_inputs.shape[1]
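# Assuming 16 kHz audio and 60 ms per LFR frame (960 samples), the stride
# below covers chunk_size[1] frames of waveform per decoding step.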
stride_size = chunk_size[1] * 960
cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
final_result = ""
for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
if sample_offset + stride_size >= speech_length - 1:
stride_size = speech_length - sample_offset
cache["encoder"]["is_final"] = True
else:
cache["encoder"]["is_final"] = False
input_lens = torch.tensor([stride_size])
asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
if len(asr_result) != 0:
final_result += asr_result[0]
item = {'key': "utt", 'value': [final_result]}
else:
input_lens = torch.tensor([raw_inputs.shape[1]])
cache["encoder"]["is_final"] = is_final
asr_result = speech2text(cache, raw_inputs, input_lens)
item = {'key': "utt", 'value': asr_result}
asr_result_list.append(item)
if is_final:
cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1)
return asr_result_list
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
parser.add_argument(
"--hotword",
type=str_or_none,
default=None,
help="hotword file path or hotwords seperated by space"
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global cmvn file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--word_lm_train_config",
type=str,
help="Word LM training configuration",
)
group.add_argument(
"--word_lm_file",
type=str,
help="Word LM parameter file",
)
group.add_argument(
"--ngram_file",
type=str,
help="N-gram parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
group.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), it uses a end-detect "
"function "
"to automatically find maximum hypothesis lengths."
"If maxlenratio<0.0, its absolute value is interpreted"
"as a constant max output length",
)
group.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.5,
help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
group.add_argument(
"--frontend_conf",
default=None,
help="",
)
group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, refers from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, refers from the training args",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
param_dict = {'hotword': args.hotword}
kwargs = vars(args)
kwargs.pop("config", None)
kwargs['param_dict'] = param_dict
inference(**kwargs)
if __name__ == "__main__":
main()
# from modelscope.pipelines import pipeline
# from modelscope.utils.constant import Tasks
#
# inference_16k_pipline = pipeline(
# task=Tasks.auto_speech_recognition,
# model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
#
# rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
# print(rec_result)
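# A minimal streaming sketch (added for illustration, not part of the original
# code; the model paths and chunk sizes are placeholders). It reuses the same
# param_dict across calls so the encoder/decoder cache is carried between
# chunks, and marks the last chunk with is_final=True.
#
# import soundfile
# pipeline = inference_modelscope(
#     maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=1, ngpu=0,
#     ctc_weight=0.0, lm_weight=0.0, penalty=0.0, log_level="INFO",
#     asr_train_config="exp/config.yaml", asr_model_file="exp/model.pb",
#     cmvn_file="exp/am.mvn")
# speech, fs = soundfile.read("speech.wav", dtype="float32")
# param_dict = {"cache": {}, "is_final": False, "chunk_size": [5, 10, 5]}
# stride = param_dict["chunk_size"][1] * 960
# text = ""
# for start in range(0, len(speech), stride):
#     param_dict["is_final"] = start + stride >= len(speech)
#     result = pipeline(None, raw_inputs=speech[start:start + stride],
#                       param_dict=param_dict)
#     if result and result[0]["value"]:
#         text += result[0]["value"][0]
# print(text)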

View File

@@ -0,0 +1,549 @@
#!/usr/bin/env python3
import json
import argparse
import logging
import sys
import time
from pathlib import Path
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
import math
import numpy as np
import torch
from typeguard import check_argument_types
from funasr_local.fileio.datadir_writer import DatadirWriter
from funasr_local.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr_local.modules.beam_search.beam_search import Hypothesis
from funasr_local.modules.scorers.ctc import CTCPrefixScorer
from funasr_local.modules.scorers.length_bonus import LengthBonus
from funasr_local.modules.subsampling import TooShortUttError
from funasr_local.tasks.asr import ASRTaskParaformer as ASRTask
from funasr_local.tasks.lm import LMTask
from funasr_local.text.build_tokenizer import build_tokenizer
from funasr_local.text.token_id_converter import TokenIDConverter
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from funasr_local.utils import asr_utils, wav_utils, postprocess_utils
from funasr_local.models.frontend.wav_frontend import WavFrontend
from funasr_local.tasks.vad import VADTask
from funasr_local.bin.punctuation_infer import Text2Punc
from funasr_local.bin.asr_inference_paraformer_vad_punc import Speech2Text
from funasr_local.bin.asr_inference_paraformer_vad_punc import Speech2VadSegment
def inference(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
vad_infer_config: Optional[str] = None,
vad_model_file: Optional[str] = None,
vad_cmvn_file: Optional[str] = None,
time_stamp_writer: bool = False,
punc_infer_config: Optional[str] = None,
punc_model_file: Optional[str] = None,
**kwargs,
):
inference_pipeline = inference_modelscope(
maxlenratio=maxlenratio,
minlenratio=minlenratio,
batch_size=batch_size,
beam_size=beam_size,
ngpu=ngpu,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
penalty=penalty,
log_level=log_level,
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
raw_inputs=raw_inputs,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
key_file=key_file,
word_lm_train_config=word_lm_train_config,
bpemodel=bpemodel,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
output_dir=output_dir,
dtype=dtype,
seed=seed,
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
vad_infer_config=vad_infer_config,
vad_model_file=vad_model_file,
vad_cmvn_file=vad_cmvn_file,
time_stamp_writer=time_stamp_writer,
punc_infer_config=punc_infer_config,
punc_model_file=punc_model_file,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
# data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
vad_infer_config: Optional[str] = None,
vad_model_file: Optional[str] = None,
vad_cmvn_file: Optional[str] = None,
time_stamp_writer: bool = True,
punc_infer_config: Optional[str] = None,
punc_model_file: Optional[str] = None,
outputs_dict: Optional[bool] = True,
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
else:
hotword_list_or_file = None
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2vadsegment
speech2vadsegment_kwargs = dict(
vad_infer_config=vad_infer_config,
vad_model_file=vad_model_file,
vad_cmvn_file=vad_cmvn_file,
device=device,
dtype=dtype,
)
# logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
# 3. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
hotword_list_or_file=hotword_list_or_file,
)
speech2text = Speech2Text(**speech2text_kwargs)
text2punc = None
if punc_model_file is not None:
text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
if output_dir is not None:
writer = DatadirWriter(output_dir)
ibest_writer = writer[f"1best_recog"]
ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
**kwargs,
):
hotword_list_or_file = None
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
if 'hotword' in kwargs:
hotword_list_or_file = kwargs['hotword']
if speech2text.hotword_list is None:
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=1,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
if param_dict is not None:
use_timestamp = param_dict.get('use_timestamp', True)
else:
use_timestamp = True
finish_count = 0
file_count = 1
lfr_factor = 6
# 7 .Start for-loop
asr_result_list = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
writer = None
if output_path is not None:
writer = DatadirWriter(output_path)
ibest_writer = writer[f"1best_recog"]
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
vad_results = speech2vadsegment(**batch)
fbanks, vadsegments = vad_results[0], vad_results[1]
for i, segments in enumerate(vadsegments):
result_segments = [["", [], [], ]]
for j, segment_idx in enumerate(segments):
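# VAD segment boundaries are in milliseconds; dividing by 10 converts them
# to 10 ms fbank frame indices before slicing the feature matrix.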
bed_idx, end_idx = int(segment_idx[0] / 10), int(segment_idx[1] / 10)
segment = fbanks[:, bed_idx:end_idx, :].to(device)
speech_lengths = torch.Tensor([end_idx - bed_idx]).int().to(device)
batch = {"speech": segment, "speech_lengths": speech_lengths, "begin_time": vadsegments[i][j][0],
"end_time": vadsegments[i][j][1]}
results = speech2text(**batch)
if len(results) < 1:
continue
result_cur = [results[0][:-2]]
if j == 0:
result_segments = result_cur
else:
result_segments = [[result_segments[0][i] + result_cur[0][i] for i in range(len(result_cur[0]))]]
key = keys[0]
result = result_segments[0]
text, token, token_int = result[0], result[1], result[2]
time_stamp = None if len(result) < 4 else result[3]
if use_timestamp and time_stamp is not None:
postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
else:
postprocessed_result = postprocess_utils.sentence_postprocess(token)
text_postprocessed = ""
time_stamp_postprocessed = ""
text_postprocessed_punc = postprocessed_result
if len(postprocessed_result) == 3:
text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
postprocessed_result[1], \
postprocessed_result[2]
else:
text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
text_postprocessed_punc = text_postprocessed
if len(word_lists) > 0 and text2punc is not None:
text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
item = {'key': key, 'value': text_postprocessed_punc}
if text_postprocessed != "":
item['text_postprocessed'] = text_postprocessed
if time_stamp_postprocessed != "":
item['time_stamp'] = time_stamp_postprocessed
asr_result_list.append(item)
finish_count += 1
# asr_utils.print_progress(finish_count / file_count)
if writer is not None:
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["vad"][key] = "{}".format(vadsegments)
ibest_writer["text"][key] = " ".join(word_lists)
ibest_writer["text_with_punc"][key] = text_postprocessed_punc
if time_stamp_postprocessed is not None:
ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
return asr_result_list
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global cmvn file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--word_lm_train_config",
type=str,
help="Word LM training configuration",
)
group.add_argument(
"--word_lm_file",
type=str,
help="Word LM parameter file",
)
group.add_argument(
"--ngram_file",
type=str,
help="N-gram parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
group.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), it uses a end-detect "
"function "
"to automatically find maximum hypothesis lengths."
"If maxlenratio<0.0, its absolute value is interpreted"
"as a constant max output length",
)
group.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.5,
help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
group.add_argument("--time_stamp_writer", type=str2bool, default=False)
group.add_argument(
"--frontend_conf",
default=None,
help="",
)
group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, refers from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, refers from the training args",
)
group.add_argument(
"--vad_infer_config",
type=str,
help="VAD infer configuration",
)
group.add_argument(
"--vad_model_file",
type=str,
help="VAD model parameter file",
)
group.add_argument(
"--vad_cmvn_file",
type=str,
help="vad, Global cmvn file",
)
group.add_argument(
"--punc_infer_config",
type=str,
help="VAD infer configuration",
)
group.add_argument(
"--punc_model_file",
type=str,
help="VAD model parameter file",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()
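# Example invocation (a sketch added for illustration; the script name and all
# paths are placeholders, flags are those declared in get_parser() above):
#
# python <this_script>.py \
#     --asr_train_config exp/asr/config.yaml \
#     --asr_model_file exp/asr/model.pb \
#     --cmvn_file exp/asr/am.mvn \
#     --vad_infer_config exp/vad/vad.yaml \
#     --vad_model_file exp/vad/vad.pb \
#     --vad_cmvn_file exp/vad/vad.mvn \
#     --punc_infer_config exp/punc/punc.yaml \
#     --punc_model_file exp/punc/punc.pb \
#     --data_path_and_name_and_type wav.scp,speech,sound \
#     --output_dir exp/decode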

View File

@@ -0,0 +1,881 @@
#!/usr/bin/env python3
import json
import argparse
import logging
import sys
import time
import os
import codecs
import tempfile
import requests
from pathlib import Path
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
import math
import copy
import numpy as np
import torch
from typeguard import check_argument_types
from funasr_local.fileio.datadir_writer import DatadirWriter
from funasr_local.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr_local.modules.beam_search.beam_search import Hypothesis
from funasr_local.modules.scorers.ctc import CTCPrefixScorer
from funasr_local.modules.scorers.length_bonus import LengthBonus
from funasr_local.modules.subsampling import TooShortUttError
from funasr_local.tasks.asr import ASRTaskParaformer as ASRTask
from funasr_local.tasks.lm import LMTask
from funasr_local.text.build_tokenizer import build_tokenizer
from funasr_local.text.token_id_converter import TokenIDConverter
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from funasr_local.utils import asr_utils, wav_utils, postprocess_utils
from funasr_local.models.frontend.wav_frontend import WavFrontend
from funasr_local.tasks.vad import VADTask
from funasr_local.bin.vad_inference import Speech2VadSegment
from funasr_local.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr_local.bin.punctuation_infer import Text2Punc
from funasr_local.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
header_colors = '\033[95m'
end_colors = '\033[0m'
class Speech2Text:
"""Speech2Text class
Examples:
>>> import soundfile
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
frontend_conf: dict = None,
hotword_list_or_file: str = None,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file=cmvn_file, device=device
)
frontend = None
if asr_model.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
# logging.info("asr_model: {}".format(asr_model))
# logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
if asr_model.ctc is not None:
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
scorers.update(
ctc=ctc
)
token_list = asr_model.token_list
scorers.update(
length_bonus=LengthBonus(len(token_list)),
)
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device
)
scorers["lm"] = lm.lm
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
lm=lm_weight,
ngram=ngram_weight,
length_bonus=penalty,
)
beam_search = BeamSearch(
beam_size=beam_size,
weights=weights,
scorers=scorers,
sos=asr_model.sos,
eos=asr_model.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
for scorer in scorers.values():
if isinstance(scorer, torch.nn.Module):
scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
logging.info(f"Decoding device={device}, dtype={dtype}")
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
# 6. [Optional] Build hotword list from str, local file or url
self.hotword_list = None
self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
is_use_lm = lm_weight != 0.0 and lm_file is not None
if (ctc_weight == 0.0 or asr_model.ctc is None) and not is_use_lm:
beam_search = None
self.beam_search = beam_search
logging.info(f"Beam_search: {self.beam_search}")
self.beam_search_transducer = beam_search_transducer
self.maxlenratio = maxlenratio
self.minlenratio = minlenratio
self.device = device
self.dtype = dtype
self.nbest = nbest
self.frontend = frontend
self.encoder_downsampling_factor = 1
if asr_train_args.encoder_conf["input_layer"] == "conv2d":
self.encoder_downsampling_factor = 4
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
begin_time: int = 0, end_time: int = None,
):
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if self.frontend is not None:
# feats, feats_len = self.frontend.forward(speech, speech_lengths)
# fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
feats, feats_len = self.frontend.forward_lfr_cmvn(speech, speech_lengths)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
self.asr_model.frontend = None
else:
feats = speech
feats_len = speech_lengths
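# Estimate the LFR stacking factor from the feature dimension, assuming
# 80-dim filterbanks (e.g. 560-dim LFR features give a factor of 6).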
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
# b. Forward Encoder
enc, enc_len = self.asr_model.encode(**batch)
if isinstance(enc, tuple):
enc = enc[0]
# assert len(enc) == 1, len(enc)
enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
predictor_outs[2], predictor_outs[3]
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
if not isinstance(self.asr_model, ContextualParaformer):
if self.hotword_list:
logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
else:
decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
if isinstance(self.asr_model, BiCifParaformer):
_, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
pre_token_length) # test no bias cif2
results = []
b, n, d = decoder_out.size()
for i in range(b):
x = enc[i, :enc_len[i], :]
am_scores = decoder_out[i, :pre_token_length[i], :]
if self.beam_search is not None:
nbest_hyps = self.beam_search(
x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
)
nbest_hyps = nbest_hyps[: self.nbest]
else:
yseq = am_scores.argmax(dim=-1)
score = am_scores.max(dim=-1)[0]
score = torch.sum(score, dim=-1)
# wrap the greedy output with sos/eos so it matches the beam-search hypothesis format
yseq = torch.tensor(
[self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
)
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove the blank symbol id (assumed to be 0) and symbol id 2, which are filtered as non-text tokens
token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
if len(token_int) == 0:
continue
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
if isinstance(self.asr_model, BiCifParaformer):
_, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
us_peaks[i],
copy.copy(token),
vad_offset=begin_time)
results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
else:
results.append((text, token, token_int, enc_len_batch_total, lfr_factor))
# assert check_return_type(results)
return results
def generate_hotwords_list(self, hotword_list_or_file):
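# Accepted inputs (descriptive summary): None, a local .txt file with one
# hotword per line, an http(s) URL to such a file, or a plain string of
# space-separated hotwords. Each hotword is split into characters, mapped to
# token ids, and the sos id is appended as a final '<s>' entry.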
# for None
if hotword_list_or_file is None:
hotword_list = None
# for local txt inputs
elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
logging.info("Attempting to parse hotwords from local txt...")
hotword_list = []
hotword_str_list = []
with codecs.open(hotword_list_or_file, 'r') as fin:
for line in fin.readlines():
hw = line.strip()
hotword_str_list.append(hw)
hotword_list.append(self.converter.tokens2ids([i for i in hw]))
hotword_list.append([self.asr_model.sos])
hotword_str_list.append('<s>')
logging.info("Initialized hotword list from file: {}, hotword list: {}."
.format(hotword_list_or_file, hotword_str_list))
# for url, download and generate txt
elif hotword_list_or_file.startswith('http'):
logging.info("Attempting to parse hotwords from url...")
work_dir = tempfile.mkdtemp()
if not os.path.exists(work_dir):
os.makedirs(work_dir)
text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
local_file = requests.get(hotword_list_or_file)
with open(text_file_path, "wb") as fout:
    fout.write(local_file.content)
hotword_list_or_file = text_file_path
hotword_list = []
hotword_str_list = []
with codecs.open(hotword_list_or_file, 'r') as fin:
for line in fin.readlines():
hw = line.strip()
hotword_str_list.append(hw)
hotword_list.append(self.converter.tokens2ids([i for i in hw]))
hotword_list.append([self.asr_model.sos])
hotword_str_list.append('<s>')
logging.info("Initialized hotword list from file: {}, hotword list: {}."
.format(hotword_list_or_file, hotword_str_list))
# for text str input
elif not hotword_list_or_file.endswith('.txt'):
logging.info("Attempting to parse hotwords as str...")
hotword_list = []
hotword_str_list = []
for hw in hotword_list_or_file.strip().split():
hotword_str_list.append(hw)
hotword_list.append(self.converter.tokens2ids([i for i in hw]))
hotword_list.append([self.asr_model.sos])
hotword_str_list.append('<s>')
logging.info("Hotword list: {}.".format(hotword_str_list))
else:
hotword_list = None
return hotword_list
def inference(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
vad_infer_config: Optional[str] = None,
vad_model_file: Optional[str] = None,
vad_cmvn_file: Optional[str] = None,
time_stamp_writer: bool = False,
punc_infer_config: Optional[str] = None,
punc_model_file: Optional[str] = None,
**kwargs,
):
inference_pipeline = inference_modelscope(
maxlenratio=maxlenratio,
minlenratio=minlenratio,
batch_size=batch_size,
beam_size=beam_size,
ngpu=ngpu,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
penalty=penalty,
log_level=log_level,
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
raw_inputs=raw_inputs,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
key_file=key_file,
word_lm_train_config=word_lm_train_config,
bpemodel=bpemodel,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
output_dir=output_dir,
dtype=dtype,
seed=seed,
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
vad_infer_config=vad_infer_config,
vad_model_file=vad_model_file,
vad_cmvn_file=vad_cmvn_file,
time_stamp_writer=time_stamp_writer,
punc_infer_config=punc_infer_config,
punc_model_file=punc_model_file,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
# data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
vad_infer_config: Optional[str] = None,
vad_model_file: Optional[str] = None,
vad_cmvn_file: Optional[str] = None,
time_stamp_writer: bool = True,
punc_infer_config: Optional[str] = None,
punc_model_file: Optional[str] = None,
outputs_dict: Optional[bool] = True,
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
else:
hotword_list_or_file = None
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2vadsegment
speech2vadsegment_kwargs = dict(
vad_infer_config=vad_infer_config,
vad_model_file=vad_model_file,
vad_cmvn_file=vad_cmvn_file,
device=device,
dtype=dtype,
)
# logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
# 3. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
hotword_list_or_file=hotword_list_or_file,
)
speech2text = Speech2Text(**speech2text_kwargs)
text2punc = None
if punc_model_file is not None:
text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
if output_dir is not None:
writer = DatadirWriter(output_dir)
ibest_writer = writer[f"1best_recog"]
ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
**kwargs,
):
hotword_list_or_file = None
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
if 'hotword' in kwargs:
hotword_list_or_file = kwargs['hotword']
if speech2text.hotword_list is None:
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=1,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
if param_dict is not None:
use_timestamp = param_dict.get('use_timestamp', True)
else:
use_timestamp = True
finish_count = 0
file_count = 1
lfr_factor = 6
# 7 .Start for-loop
asr_result_list = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
writer = None
if output_path is not None:
writer = DatadirWriter(output_path)
ibest_writer = writer[f"1best_recog"]
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
vad_results = speech2vadsegment(**batch)
fbanks, vadsegments = vad_results[0], vad_results[1]
for i, segments in enumerate(vadsegments):
result_segments = [["", [], [], []]]
for j, segment_idx in enumerate(segments):
bed_idx, end_idx = int(segment_idx[0] / 10), int(segment_idx[1] / 10)
segment = fbanks[:, bed_idx:end_idx, :].to(device)
speech_lengths = torch.Tensor([end_idx - bed_idx]).int().to(device)
batch = {"speech": segment, "speech_lengths": speech_lengths, "begin_time": vadsegments[i][j][0],
"end_time": vadsegments[i][j][1]}
results = speech2text(**batch)
if len(results) < 1:
continue
result_cur = [results[0][:-2]]
if j == 0:
result_segments = result_cur
else:
result_segments = [
[result_segments[0][i] + result_cur[0][i] for i in range(len(result_cur[0]))]]
key = keys[0]
result = result_segments[0]
text, token, token_int = result[0], result[1], result[2]
time_stamp = None if len(result) < 4 else result[3]
if use_timestamp and time_stamp is not None:
postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
else:
postprocessed_result = postprocess_utils.sentence_postprocess(token)
text_postprocessed = ""
time_stamp_postprocessed = ""
text_postprocessed_punc = postprocessed_result
if len(postprocessed_result) == 3:
text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
postprocessed_result[1], \
postprocessed_result[2]
else:
text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
text_postprocessed_punc = text_postprocessed
punc_id_list = []
if len(word_lists) > 0 and text2punc is not None:
text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
item = {'key': key, 'value': text_postprocessed_punc}
if text_postprocessed != "":
item['text_postprocessed'] = text_postprocessed
if time_stamp_postprocessed != "":
item['time_stamp'] = time_stamp_postprocessed
item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
asr_result_list.append(item)
finish_count += 1
# asr_utils.print_progress(finish_count / file_count)
if writer is not None:
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["vad"][key] = "{}".format(vadsegments)
ibest_writer["text"][key] = " ".join(word_lists)
ibest_writer["text_with_punc"][key] = text_postprocessed_punc
if time_stamp_postprocessed is not None:
ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
return asr_result_list
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global cmvn file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--word_lm_train_config",
type=str,
help="Word LM training configuration",
)
group.add_argument(
"--word_lm_file",
type=str,
help="Word LM parameter file",
)
group.add_argument(
"--ngram_file",
type=str,
help="N-gram parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
group.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), it uses a end-detect "
"function "
"to automatically find maximum hypothesis lengths."
"If maxlenratio<0.0, its absolute value is interpreted"
"as a constant max output length",
)
group.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.5,
help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
group.add_argument("--time_stamp_writer", type=str2bool, default=False)
group.add_argument(
"--frontend_conf",
default=None,
help="",
)
group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, refers from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, refers from the training args",
)
group.add_argument(
"--vad_infer_config",
type=str,
help="VAD infer configuration",
)
group.add_argument(
"--vad_model_file",
type=str,
help="VAD model parameter file",
)
group.add_argument(
"--vad_cmvn_file",
type=str,
help="vad, Global cmvn file",
)
group.add_argument(
"--punc_infer_config",
type=str,
help="VAD infer configuration",
)
group.add_argument(
"--punc_model_file",
type=str,
help="VAD model parameter file",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,737 @@
#!/usr/bin/env python3
""" Inference class definition for Transducer models."""
from __future__ import annotations
import argparse
import logging
import math
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from packaging.version import parse as V
from typeguard import check_argument_types, check_return_type
from funasr_local.modules.beam_search.beam_search_transducer import (
BeamSearchTransducer,
Hypothesis,
)
from funasr_local.modules.nets_utils import TooShortUttError
from funasr_local.fileio.datadir_writer import DatadirWriter
from funasr_local.tasks.asr import ASRTransducerTask
from funasr_local.tasks.lm import LMTask
from funasr_local.text.build_tokenizer import build_tokenizer
from funasr_local.text.token_id_converter import TokenIDConverter
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.types import str2bool, str2triple_str, str_or_none
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.models.frontend.wav_frontend import WavFrontend
class Speech2Text:
"""Speech2Text class for Transducer models.
Args:
asr_train_config: ASR model training config path.
asr_model_file: ASR model path.
beam_search_config: Beam search config path.
lm_train_config: Language Model training config path.
lm_file: Language Model path.
token_type: Type of token units.
bpemodel: BPE model path.
device: Device to use for inference.
beam_size: Size of beam during search.
dtype: Data type.
lm_weight: Language model weight.
quantize_asr_model: Whether to apply dynamic quantization to ASR model.
quantize_modules: List of module names to apply dynamic quantization on.
quantize_dtype: Dynamic quantization data type.
nbest: Number of final hypotheses.
streaming: Whether to perform chunk-by-chunk inference.
chunk_size: Number of frames in chunk AFTER subsampling.
left_context: Number of frames in left context AFTER subsampling.
right_context: Number of frames in right context AFTER subsampling.
display_partial_hypotheses: Whether to display partial hypotheses.
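Examples:
A minimal illustrative sketch (the file names are hypothetical placeholders and
the expected input must match the model's training-time preprocessing):
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
>>> nbest_hyps = speech2text(speech)
>>> speech2text.hypotheses_to_results(nbest_hyps)
[(text, token, token_int, hypothesis object), ...]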
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
beam_search_config: Dict[str, Any] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
beam_size: int = 5,
dtype: str = "float32",
lm_weight: float = 1.0,
quantize_asr_model: bool = False,
quantize_modules: List[str] = None,
quantize_dtype: str = "qint8",
nbest: int = 1,
streaming: bool = False,
simu_streaming: bool = False,
chunk_size: int = 16,
left_context: int = 32,
right_context: int = 0,
display_partial_hypotheses: bool = False,
) -> None:
"""Construct a Speech2Text object."""
super().__init__()
assert check_argument_types()
asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
if quantize_asr_model:
if quantize_modules is not None:
if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
raise ValueError(
"Only 'Linear' and 'LSTM' modules are currently supported"
" by PyTorch and in --quantize_modules"
)
q_config = set([getattr(torch.nn, q) for q in quantize_modules])
else:
q_config = {torch.nn.Linear}
if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
raise ValueError(
"float16 dtype for dynamic quantization is not supported with torch"
" version < 1.5.0. Please use the qint8 dtype instead."
)
q_dtype = getattr(torch, quantize_dtype)
asr_model = torch.quantization.quantize_dynamic(
asr_model, q_config, dtype=q_dtype
).eval()
else:
asr_model.to(dtype=getattr(torch, dtype)).eval()
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device
)
lm_scorer = lm.lm
else:
lm_scorer = None
# 4. Build BeamSearch object
if beam_search_config is None:
beam_search_config = {}
beam_search = BeamSearchTransducer(
asr_model.decoder,
asr_model.joint_network,
beam_size,
lm=lm_scorer,
lm_weight=lm_weight,
nbest=nbest,
**beam_search_config,
)
token_list = asr_model.token_list
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.device = device
self.dtype = dtype
self.nbest = nbest
self.converter = converter
self.tokenizer = tokenizer
self.beam_search = beam_search
self.streaming = streaming
self.simu_streaming = simu_streaming
self.chunk_size = max(chunk_size, 0)
self.left_context = left_context
self.right_context = max(right_context, 0)
if not streaming or chunk_size == 0:
self.streaming = False
self.asr_model.encoder.dynamic_chunk_training = False
if not simu_streaming or chunk_size == 0:
self.simu_streaming = False
self.asr_model.encoder.dynamic_chunk_training = False
self.frontend = frontend
self.window_size = self.chunk_size + self.right_context
self._ctx = self.asr_model.encoder.get_encoder_input_size(
self.window_size
)
#self.last_chunk_length = (
# self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
#) * self.hop_length
self.last_chunk_length = (
self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
)
self.reset_inference_cache()
def reset_inference_cache(self) -> None:
"""Reset Speech2Text parameters."""
self.frontend_cache = None
self.asr_model.encoder.reset_streaming_cache(
self.left_context, device=self.device
)
self.beam_search.reset_inference_cache()
self.num_processed_frames = torch.tensor([[0]], device=self.device)
@torch.no_grad()
def streaming_decode(
self,
speech: Union[torch.Tensor, np.ndarray],
is_final: bool = True,
) -> List[Hypothesis]:
"""Speech2Text streaming call.
Args:
speech: Chunk of speech data. (S)
is_final: Whether speech corresponds to the final chunk of data.
Returns:
nbest_hypothesis: N-best hypothesis.
"""
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if is_final:
if self.streaming and speech.size(0) < self.last_chunk_length:
pad = torch.zeros(
self.last_chunk_length - speech.size(0), speech.size(1), dtype=speech.dtype
)
speech = torch.cat([speech, pad], dim=0) #feats, feats_length = self.apply_frontend(speech, is_final=is_final)
feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
if self.asr_model.normalize is not None:
feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
feats = to_device(feats, device=self.device)
feats_lengths = to_device(feats_lengths, device=self.device)
enc_out = self.asr_model.encoder.chunk_forward(
feats,
feats_lengths,
self.num_processed_frames,
chunk_size=self.chunk_size,
left_context=self.left_context,
right_context=self.right_context,
)
nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
self.num_processed_frames += self.chunk_size
if is_final:
self.reset_inference_cache()
return nbest_hyps
@torch.no_grad()
def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]:
"""Speech2Text call.
Args:
speech: Speech data. (S)
Returns:
nbest_hypothesis: N-best hypothesis.
"""
assert check_argument_types()
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
if self.asr_model.normalize is not None:
feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
feats = to_device(feats, device=self.device)
feats_lengths = to_device(feats_lengths, device=self.device)
enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context, self.right_context)
nbest_hyps = self.beam_search(enc_out[0])
return nbest_hyps
@torch.no_grad()
def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]:
"""Speech2Text call.
Args:
speech: Speech data. (S)
Returns:
nbest_hypothesis: N-best hypothesis.
"""
assert check_argument_types()
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
feats = to_device(feats, device=self.device)
feats_lengths = to_device(feats_lengths, device=self.device)
enc_out, _ = self.asr_model.encoder(feats, feats_lengths)
nbest_hyps = self.beam_search(enc_out[0])
return nbest_hyps
def hypotheses_to_results(self, nbest_hyps: List[Hypothesis]) -> List[Any]:
"""Build partial or final results from the hypotheses.
Args:
nbest_hyps: N-best hypothesis.
Returns:
results: Results containing different representation for the hypothesis.
"""
results = []
for hyp in nbest_hyps:
token_int = list(filter(lambda x: x != 0, hyp.yseq))
token = self.converter.ids2tokens(token_int)
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp))
assert check_return_type(results)
return results
@staticmethod
def from_pretrained(
model_tag: Optional[str] = None,
**kwargs: Optional[Any],
) -> Speech2Text:
"""Build Speech2Text instance from the pretrained model.
Args:
model_tag: Model tag of the pretrained models.
Return:
: Speech2Text instance.
"""
if model_tag is not None:
try:
from espnet_model_zoo.downloader import ModelDownloader
except ImportError:
logging.error(
"`espnet_model_zoo` is not installed. "
"Please install via `pip install -U espnet_model_zoo`."
)
raise
d = ModelDownloader()
kwargs.update(**d.download_and_unpack(model_tag))
return Speech2Text(**kwargs)
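# Illustrative sketch: when a model_tag is given, from_pretrained requires the optional
# `espnet_model_zoo` package; the tag below is a hypothetical placeholder.
#   speech2text = Speech2Text.from_pretrained(model_tag="<namespace/model_name>", beam_size=5)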
def inference(
output_dir: str,
batch_size: int,
dtype: str,
beam_size: int,
ngpu: int,
seed: int,
lm_weight: float,
nbest: int,
num_workers: int,
log_level: Union[int, str],
data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str],
beam_search_config: Optional[dict],
lm_train_config: Optional[str],
lm_file: Optional[str],
model_tag: Optional[str],
token_type: Optional[str],
bpemodel: Optional[str],
key_file: Optional[str],
allow_variable_data_keys: bool,
quantize_asr_model: Optional[bool],
quantize_modules: Optional[List[str]],
quantize_dtype: Optional[str],
streaming: Optional[bool],
simu_streaming: Optional[bool],
chunk_size: Optional[int],
left_context: Optional[int],
right_context: Optional[int],
display_partial_hypotheses: bool,
**kwargs,
) -> None:
"""Transducer model inference.
Args:
output_dir: Output directory path.
batch_size: Batch decoding size.
dtype: Data type.
beam_size: Beam size.
ngpu: Number of GPUs.
seed: Random number generator seed.
lm_weight: Weight of language model.
nbest: Number of final hypotheses.
num_workers: Number of workers.
log_level: Level of verbose for logs.
data_path_and_name_and_type:
asr_train_config: ASR model training config path.
asr_model_file: ASR model path.
beam_search_config: Beam search config path.
lm_train_config: Language Model training config path.
lm_file: Language Model path.
model_tag: Model tag.
token_type: Type of token units.
bpemodel: BPE model path.
key_file: File key.
allow_variable_data_keys: Whether to allow variable data keys.
quantize_asr_model: Whether to apply dynamic quantization to ASR model.
quantize_modules: List of module names to apply dynamic quantization on.
quantize_dtype: Dynamic quantization data type.
streaming: Whether to perform chunk-by-chunk inference.
chunk_size: Number of frames in chunk AFTER subsampling.
left_context: Number of frames in left context AFTER subsampling.
right_context: Number of frames in right context AFTER subsampling.
display_partial_hypotheses: Whether to display partial hypotheses.
"""
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1:
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
beam_search_config=beam_search_config,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
dtype=dtype,
beam_size=beam_size,
lm_weight=lm_weight,
nbest=nbest,
quantize_asr_model=quantize_asr_model,
quantize_modules=quantize_modules,
quantize_dtype=quantize_dtype,
streaming=streaming,
simu_streaming=simu_streaming,
chunk_size=chunk_size,
left_context=left_context,
right_context=right_context,
)
speech2text = Speech2Text.from_pretrained(
model_tag=model_tag,
**speech2text_kwargs,
)
# 3. Build data-iterator
loader = ASRTransducerTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=ASRTransducerTask.build_preprocess_fn(
speech2text.asr_train_args, False
),
collate_fn=ASRTransducerTask.build_collate_fn(
speech2text.asr_train_args, False
),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
# 4. Start for-loop
with DatadirWriter(output_dir) as writer:
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
assert len(batch.keys()) == 1
try:
if speech2text.streaming:
speech = batch["speech"]
_steps = len(speech) // speech2text._ctx
_end = 0
for i in range(_steps):
_end = (i + 1) * speech2text._ctx
speech2text.streaming_decode(
speech[i * speech2text._ctx : _end], is_final=False
)
final_hyps = speech2text.streaming_decode(
speech[_end : len(speech)], is_final=True
)
elif speech2text.simu_streaming:
final_hyps = speech2text.simu_streaming_decode(**batch)
else:
final_hyps = speech2text(**batch)
results = speech2text.hypotheses_to_results(final_hyps)
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
results = [[" ", ["<space>"], [2], hyp]] * nbest
key = keys[0]
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
ibest_writer = writer[f"{n}best_recog"]
ibest_writer["token"][key] = " ".join(token)
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
ibest_writer["text"][key] = text
def get_parser():
"""Get Transducer model inference parser."""
parser = config_argparse.ArgumentParser(
description="ASR Transducer Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=True,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global cmvn file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If this option is specified, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=5, help="Beam size")
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument(
"--beam_search_config",
default={},
help="The keyword arguments for transducer beam search.",
)
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, it is inferred from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, it is inferred from the training args",
)
group = parser.add_argument_group("Dynamic quantization related")
parser.add_argument(
"--quantize_asr_model",
type=str2bool,
default=False,
help="Apply dynamic quantization to ASR model.",
)
parser.add_argument(
"--quantize_modules",
nargs="*",
default=None,
help="""Module names to apply dynamic quantization on.
The module names are provided as a space-separated list
(e.g.: --quantize_modules Linear LSTM).
Each specified name should be an attribute of 'torch.nn', e.g.:
torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
)
parser.add_argument(
"--quantize_dtype",
type=str,
default="qint8",
choices=["float16", "qint8"],
help="Dtype for dynamic quantization.",
)
group = parser.add_argument_group("Streaming related")
parser.add_argument(
"--streaming",
type=str2bool,
default=False,
help="Whether to perform chunk-by-chunk inference.",
)
parser.add_argument(
"--simu_streaming",
type=str2bool,
default=False,
help="Whether to simulate chunk-by-chunk inference.",
)
parser.add_argument(
"--chunk_size",
type=int,
default=16,
help="Number of frames in chunk AFTER subsampling.",
)
parser.add_argument(
"--left_context",
type=int,
default=32,
help="Number of frames in left context of the chunk AFTER subsampling.",
)
parser.add_argument(
"--right_context",
type=int,
default=0,
help="Number of frames in right context of the chunk AFTER subsampling.",
)
parser.add_argument(
"--display_partial_hypotheses",
type=str2bool,
default=False,
help="Whether to display partial hypotheses during chunk-by-chunk inference.",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()
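# Illustrative command line (a sketch: the script name and all paths are hypothetical
# placeholders; only options defined in get_parser() above are used):
#   python asr_inference_transducer.py \
#       --output_dir exp/decode \
#       --asr_train_config exp/asr_train/config.yaml \
#       --asr_model_file exp/asr_train/model.pb \
#       --cmvn_file exp/asr_train/am.mvn \
#       --data_path_and_name_and_type data/test/wav.scp,speech,sound \
#       --beam_size 5 --nbest 1 --streaming false --chunk_size 16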

View File

@@ -0,0 +1,694 @@
#!/usr/bin/env python3
import argparse
import logging
import sys
from pathlib import Path
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.fileio.datadir_writer import DatadirWriter
from funasr_local.modules.beam_search.beam_search import BeamSearchScama as BeamSearch
from funasr_local.modules.beam_search.beam_search import Hypothesis
from funasr_local.modules.scorers.ctc import CTCPrefixScorer
from funasr_local.modules.scorers.length_bonus import LengthBonus
from funasr_local.modules.subsampling import TooShortUttError
from funasr_local.tasks.asr import ASRTaskUniASR as ASRTask
from funasr_local.tasks.lm import LMTask
from funasr_local.text.build_tokenizer import build_tokenizer
from funasr_local.text.token_id_converter import TokenIDConverter
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from funasr_local.utils import asr_utils, wav_utils, postprocess_utils
from funasr_local.models.frontend.wav_frontend import WavFrontend
class Speech2Text:
"""Speech2Text class
Examples:
>>> import soundfile
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
token_num_relax: int = 1,
decoding_ind: int = 0,
decoding_mode: str = "model1",
frontend_conf: dict = None,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
if decoding_mode == "model1":
decoder = asr_model.decoder
else:
decoder = asr_model.decoder2
if asr_model.ctc is not None:
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
scorers.update(
ctc=ctc
)
token_list = asr_model.token_list
scorers.update(
decoder=decoder,
length_bonus=LengthBonus(len(token_list)),
)
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device
)
scorers["lm"] = lm.lm
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
lm=lm_weight,
ngram=ngram_weight,
length_bonus=penalty,
)
beam_search = BeamSearch(
beam_size=beam_size,
weights=weights,
scorers=scorers,
sos=asr_model.sos,
eos=asr_model.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
for scorer in scorers.values():
if isinstance(scorer, torch.nn.Module):
scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
# logging.info(f"Beam_search: {beam_search}")
logging.info(f"Decoding device={device}, dtype={dtype}")
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
self.beam_search = beam_search
self.beam_search_transducer = beam_search_transducer
self.maxlenratio = maxlenratio
self.minlenratio = minlenratio
self.device = device
self.dtype = dtype
self.nbest = nbest
self.token_num_relax = token_num_relax
self.decoding_ind = decoding_ind
self.decoding_mode = decoding_mode
self.frontend = frontend
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
) -> List[
Tuple[
Optional[str],
List[str],
List[int],
Union[Hypothesis],
]
]:
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
self.asr_model.frontend = None
else:
feats = speech
feats_len = speech_lengths
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
feats_raw = feats.clone().to(self.device)
batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
# b. Forward Encoder
_, enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
if isinstance(enc, tuple):
enc = enc[0]
assert len(enc) == 1, len(enc)
if self.decoding_mode == "model1":
predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len)
else:
enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind)
predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len)
scama_mask = predictor_outs[4]
pre_token_length = predictor_outs[1]
pre_acoustic_embeds = predictor_outs[0]
maxlen = pre_token_length.sum().item() + self.token_num_relax
minlen = max(0, pre_token_length.sum().item() - self.token_num_relax)
# c. Pass the encoder result to the beam search
nbest_hyps = self.beam_search(
x=enc[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=self.maxlenratio,
minlenratio=self.minlenratio, maxlen=int(maxlen), minlen=int(minlen),
)
nbest_hyps = nbest_hyps[: self.nbest]
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
token = list(filter(lambda x: x != "<gbg>", token))
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp))
assert check_return_type(results)
return results
def inference(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
ngram_file: Optional[str] = None,
cmvn_file: Optional[str] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
token_num_relax: int = 1,
decoding_ind: int = 0,
decoding_mode: str = "model1",
**kwargs,
):
inference_pipeline = inference_modelscope(
maxlenratio=maxlenratio,
minlenratio=minlenratio,
batch_size=batch_size,
beam_size=beam_size,
ngpu=ngpu,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
penalty=penalty,
log_level=log_level,
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
raw_inputs=raw_inputs,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
key_file=key_file,
word_lm_train_config=word_lm_train_config,
bpemodel=bpemodel,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
output_dir=output_dir,
dtype=dtype,
seed=seed,
ngram_weight=ngram_weight,
ngram_file=ngram_file,
nbest=nbest,
num_workers=num_workers,
token_num_relax=token_num_relax,
decoding_ind=decoding_ind,
decoding_mode=decoding_mode,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
# data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
ngram_file: Optional[str] = None,
cmvn_file: Optional[str] = None,
# raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
token_num_relax: int = 1,
decoding_ind: int = 0,
decoding_mode: str = "model1",
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
if param_dict is not None and "decoding_model" in param_dict:
if param_dict["decoding_model"] == "fast":
decoding_ind = 0
decoding_mode = "model1"
elif param_dict["decoding_model"] == "normal":
decoding_ind = 0
decoding_mode = "model2"
elif param_dict["decoding_model"] == "offline":
decoding_ind = 1
decoding_mode = "model2"
else:
raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
ngram_file=ngram_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
streaming=streaming,
token_num_relax=token_num_relax,
decoding_ind=decoding_ind,
decoding_mode=decoding_mode,
)
speech2text = Speech2Text(**speech2text_kwargs)
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
**kwargs,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
finish_count = 0
file_count = 1
# 7. Start for-loop
# FIXME(kamo): The output format should be discussed
asr_result_list = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
writer = DatadirWriter(output_path)
else:
writer = None
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
#batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["sil"], [2], hyp]] * nbest
# Only supporting batch_size==1
key = keys[0]
logging.info(f"Utterance: {key}")
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = " ".join(word_lists)
return asr_result_list
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global cmvn file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--word_lm_train_config",
type=str,
help="Word LM training configuration",
)
group.add_argument(
"--word_lm_file",
type=str,
help="Word LM parameter file",
)
group.add_argument(
"--ngram_file",
type=str,
help="N-gram parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If this option is specified, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
group.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), an end-detect function is used "
"to automatically find maximum hypothesis lengths. "
"If maxlenratio<0.0, its absolute value is interpreted "
"as a constant max output length",
)
group.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.5,
help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, it is inferred from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, it is inferred from the training args",
)
group.add_argument("--token_num_relax", type=int, default=1, help="Relaxation (in tokens) around the predicted token length when setting the min/max decoding lengths")
group.add_argument("--decoding_ind", type=int, default=0, help="Index passed to the encoder during decoding")
group.add_argument("--decoding_mode", type=str, default="model1", help="Decoding mode: 'model1' uses the first-pass decoder, 'model2' uses the second-pass decoder")
group.add_argument(
"--ctc_weight2",
type=float,
default=0.0,
help="CTC weight in joint decoding",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()
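# Illustrative sketch of programmatic use (paths are hypothetical placeholders): the
# inference_modelscope() factory returns a callable pipeline that accepts either a
# wav.scp-style data specification or raw waveform input.
#   pipeline = inference_modelscope(
#       maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=20, ngpu=0,
#       ctc_weight=0.5, lm_weight=0.0, penalty=0.0, log_level="INFO",
#       asr_train_config="exp/uniasr/config.yaml", asr_model_file="exp/uniasr/model.pb",
#       cmvn_file="exp/uniasr/am.mvn",
#   )
#   results = pipeline([("data/test/wav.scp", "speech", "sound")], raw_inputs=None)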

View File

@@ -0,0 +1,695 @@
#!/usr/bin/env python3
import argparse
import logging
import sys
from pathlib import Path
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.fileio.datadir_writer import DatadirWriter
from funasr_local.modules.beam_search.beam_search import BeamSearchScama as BeamSearch
from funasr_local.modules.beam_search.beam_search import Hypothesis
from funasr_local.modules.scorers.ctc import CTCPrefixScorer
from funasr_local.modules.scorers.length_bonus import LengthBonus
from funasr_local.modules.subsampling import TooShortUttError
from funasr_local.tasks.asr import ASRTaskUniASR as ASRTask
from funasr_local.tasks.lm import LMTask
from funasr_local.text.build_tokenizer import build_tokenizer
from funasr_local.text.token_id_converter import TokenIDConverter
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from funasr_local.utils import asr_utils, wav_utils, postprocess_utils
from funasr_local.models.frontend.wav_frontend import WavFrontend
header_colors = '\033[95m'
end_colors = '\033[0m'
class Speech2Text:
"""Speech2Text class
Examples:
>>> import soundfile
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
token_num_relax: int = 1,
decoding_ind: int = 0,
decoding_mode: str = "model1",
frontend_conf: dict = None,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
if decoding_mode == "model1":
decoder = asr_model.decoder
else:
decoder = asr_model.decoder2
if asr_model.ctc is not None:
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
scorers.update(
ctc=ctc
)
token_list = asr_model.token_list
scorers.update(
decoder=decoder,
length_bonus=LengthBonus(len(token_list)),
)
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device
)
scorers["lm"] = lm.lm
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
lm=lm_weight,
ngram=ngram_weight,
length_bonus=penalty,
)
beam_search = BeamSearch(
beam_size=beam_size,
weights=weights,
scorers=scorers,
sos=asr_model.sos,
eos=asr_model.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
for scorer in scorers.values():
if isinstance(scorer, torch.nn.Module):
scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
# logging.info(f"Beam_search: {beam_search}")
logging.info(f"Decoding device={device}, dtype={dtype}")
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
self.beam_search = beam_search
self.beam_search_transducer = beam_search_transducer
self.maxlenratio = maxlenratio
self.minlenratio = minlenratio
self.device = device
self.dtype = dtype
self.nbest = nbest
self.token_num_relax = token_num_relax
self.decoding_ind = decoding_ind
self.decoding_mode = decoding_mode
self.frontend = frontend
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
) -> List[
Tuple[
Optional[str],
List[str],
List[int],
Union[Hypothesis],
]
]:
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
self.asr_model.frontend = None
else:
feats = speech
feats_len = speech_lengths
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
feats_raw = feats.clone().to(self.device)
batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
# b. Forward Encoder
_, enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
if isinstance(enc, tuple):
enc = enc[0]
assert len(enc) == 1, len(enc)
if self.decoding_mode == "model1":
predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len)
else:
enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind)
predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len)
scama_mask = predictor_outs[4]
pre_token_length = predictor_outs[1]
pre_acoustic_embeds = predictor_outs[0]
maxlen = pre_token_length.sum().item() + self.token_num_relax
minlen = max(0, pre_token_length.sum().item() - self.token_num_relax)
# c. Pass the encoder result to the beam search
nbest_hyps = self.beam_search(
x=enc[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=self.maxlenratio,
minlenratio=self.minlenratio, maxlen=int(maxlen), minlen=int(minlen),
)
nbest_hyps = nbest_hyps[: self.nbest]
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
token = list(filter(lambda x: x != "<gbg>", token))
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp))
assert check_return_type(results)
return results
def inference(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
ngram_file: Optional[str] = None,
cmvn_file: Optional[str] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
token_num_relax: int = 1,
decoding_ind: int = 0,
decoding_mode: str = "model1",
**kwargs,
):
inference_pipeline = inference_modelscope(
maxlenratio=maxlenratio,
minlenratio=minlenratio,
batch_size=batch_size,
beam_size=beam_size,
ngpu=ngpu,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
penalty=penalty,
log_level=log_level,
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
raw_inputs=raw_inputs,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
key_file=key_file,
word_lm_train_config=word_lm_train_config,
bpemodel=bpemodel,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
output_dir=output_dir,
dtype=dtype,
seed=seed,
ngram_weight=ngram_weight,
ngram_file=ngram_file,
nbest=nbest,
num_workers=num_workers,
token_num_relax=token_num_relax,
decoding_ind=decoding_ind,
decoding_mode=decoding_mode,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
# data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
ngram_file: Optional[str] = None,
cmvn_file: Optional[str] = None,
# raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
token_num_relax: int = 1,
decoding_ind: int = 0,
decoding_mode: str = "model1",
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
if param_dict is not None and "decoding_model" in param_dict:
if param_dict["decoding_model"] == "fast":
decoding_ind = 0
decoding_mode = "model1"
elif param_dict["decoding_model"] == "normal":
decoding_ind = 0
decoding_mode = "model2"
elif param_dict["decoding_model"] == "offline":
decoding_ind = 1
decoding_mode = "model2"
else:
raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
ngram_file=ngram_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
streaming=streaming,
token_num_relax=token_num_relax,
decoding_ind=decoding_ind,
decoding_mode=decoding_mode,
)
speech2text = Speech2Text(**speech2text_kwargs)
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
**kwargs,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
finish_count = 0
file_count = 1
# 7. Start for-loop
# FIXME(kamo): The output format should be discussed
asr_result_list = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
writer = DatadirWriter(output_path)
else:
writer = None
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
#batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["sil"], [2], hyp]] * nbest
# Only supporting batch_size==1
key = keys[0]
logging.info(f"Utterance: {key}")
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = " ".join(word_lists)
return asr_result_list
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global cmvn file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--word_lm_train_config",
type=str,
help="Word LM training configuration",
)
group.add_argument(
"--word_lm_file",
type=str,
help="Word LM parameter file",
)
group.add_argument(
"--ngram_file",
type=str,
help="N-gram parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If this option is specified, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
group.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), an end-detect function is used "
"to automatically find maximum hypothesis lengths. "
"If maxlenratio<0.0, its absolute value is interpreted "
"as a constant max output length",
)
group.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.5,
help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, it is inferred from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, it is inferred from the training args",
)
group.add_argument("--token_num_relax", type=int, default=1, help="Relaxation (in tokens) around the predicted token length when setting the min/max decoding lengths")
group.add_argument("--decoding_ind", type=int, default=0, help="Index passed to the encoder during decoding")
group.add_argument("--decoding_mode", type=str, default="model1", help="Decoding mode: 'model1' uses the first-pass decoder, 'model2' uses the second-pass decoder")
group.add_argument(
"--ctc_weight2",
type=float,
default=0.0,
help="CTC weight in joint decoding",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,46 @@
#!/usr/bin/env python3
import os
from funasr_local.tasks.asr import ASRTask
# for ASR Training
def parse_args():
parser = ASRTask.get_parser()
parser.add_argument(
"--gpu_id",
type=int,
default=0,
help="local gpu id.",
)
args = parser.parse_args()
return args
def main(args=None, cmd=None):
# for ASR Training
ASRTask.main(args=args, cmd=cmd)
if __name__ == '__main__':
args = parse_args()
# setup local gpu_id
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
args.distributed = True
else:
args.distributed = False
assert args.num_worker_count == 1
# re-compute batch size: when dataset type is small
if args.dataset_type == "small":
if args.batch_size is not None:
args.batch_size = args.batch_size * args.ngpu
if args.batch_bins is not None:
args.batch_bins = args.batch_bins * args.ngpu
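# Note: the configured batch_size/batch_bins are treated as per-GPU values here,
# so the global value handed to the sampler is scaled by ngpu (for example,
# batch_bins=2000000 with --ngpu 4 becomes 8000000). This presumably keeps each
# GPU's load equal to the single-GPU configuration after DDP splits the batch.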
main(args=args)

View File

@@ -0,0 +1,46 @@
#!/usr/bin/env python3
import os
from funasr_local.tasks.asr import ASRTaskParaformer as ASRTask
# for ASR Training
def parse_args():
parser = ASRTask.get_parser()
parser.add_argument(
"--gpu_id",
type=int,
default=0,
help="local gpu id.",
)
args = parser.parse_args()
return args
def main(args=None, cmd=None):
# for ASR Training
ASRTask.main(args=args, cmd=cmd)
if __name__ == '__main__':
args = parse_args()
# setup local gpu_id
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
args.distributed = True
else:
args.distributed = False
assert args.num_worker_count == 1
# re-compute batch size: when dataset type is small
if args.dataset_type == "small":
if args.batch_size is not None:
args.batch_size = args.batch_size * args.ngpu
if args.batch_bins is not None:
args.batch_bins = args.batch_bins * args.ngpu
main(args=args)

View File

@@ -0,0 +1,46 @@
#!/usr/bin/env python3
import os
from funasr_local.tasks.asr import ASRTransducerTask
# for ASR Training
def parse_args():
parser = ASRTransducerTask.get_parser()
parser.add_argument(
"--gpu_id",
type=int,
default=0,
help="local gpu id.",
)
args = parser.parse_args()
return args
def main(args=None, cmd=None):
# for ASR Training
ASRTransducerTask.main(args=args, cmd=cmd)
if __name__ == '__main__':
args = parse_args()
# setup local gpu_id
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
args.distributed = True
else:
args.distributed = False
assert args.num_worker_count == 1
# re-compute batch size: when dataset type is small
if args.dataset_type == "small":
if args.batch_size is not None:
args.batch_size = args.batch_size * args.ngpu
if args.batch_bins is not None:
args.batch_bins = args.batch_bins * args.ngpu
main(args=args)

View File

@@ -0,0 +1,46 @@
#!/usr/bin/env python3
import os
from funasr_local.tasks.asr import ASRTaskUniASR
# for ASR Training
def parse_args():
parser = ASRTaskUniASR.get_parser()
parser.add_argument(
"--gpu_id",
type=int,
default=0,
help="local gpu id.",
)
args = parser.parse_args()
return args
def main(args=None, cmd=None):
# for ASR Training
ASRTaskUniASR.main(args=args, cmd=cmd)
if __name__ == '__main__':
args = parse_args()
# setup local gpu_id
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
args.distributed = True
else:
args.distributed = False
assert args.num_worker_count == 1
# re-compute batch size: when dataset type is small
if args.dataset_type == "small":
if args.batch_size is not None:
args.batch_size = args.batch_size * args.ngpu
if args.batch_bins is not None:
args.batch_bins = args.batch_bins * args.ngpu
main(args=args)

View File

@@ -0,0 +1,145 @@
import os
import yaml
def update_dct(fin_configs, root):
if root == {}:
return {}
for root_key, root_value in root.items():
if not isinstance(root[root_key], dict):
fin_configs[root_key] = root[root_key]
else:
if root_key in fin_configs.keys():
result = update_dct(fin_configs[root_key], root[root_key])
fin_configs[root_key] = result
else:
fin_configs[root_key] = root[root_key]
return fin_configs
def parse_args(mode):
if mode == "asr":
from funasr_local.tasks.asr import ASRTask as ASRTask
elif mode == "paraformer":
from funasr_local.tasks.asr import ASRTaskParaformer as ASRTask
elif mode == "paraformer_vad_punc":
from funasr_local.tasks.asr import ASRTaskParaformer as ASRTask
elif mode == "uniasr":
from funasr_local.tasks.asr import ASRTaskUniASR as ASRTask
elif mode == "mfcca":
from funasr_local.tasks.asr import ASRTaskMFCCA as ASRTask
elif mode == "tp":
from funasr_local.tasks.asr import ASRTaskAligner as ASRTask
else:
raise ValueError("Unknown mode: {}".format(mode))
parser = ASRTask.get_parser()
args = parser.parse_args()
return args, ASRTask
def build_trainer(modelscope_dict,
data_dir,
output_dir,
train_set="train",
dev_set="validation",
distributed=False,
dataset_type="small",
batch_bins=None,
max_epoch=None,
optim=None,
lr=None,
scheduler=None,
scheduler_conf=None,
specaug=None,
specaug_conf=None,
param_dict=None,
**kwargs):
mode = modelscope_dict['mode']
args, ASRTask = parse_args(mode=mode)
# ddp related
if args.local_rank is not None:
distributed = True
else:
distributed = False
args.local_rank = args.local_rank if args.local_rank is not None else 0
local_rank = args.local_rank
if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
else:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
config = modelscope_dict['am_model_config']
finetune_config = modelscope_dict['finetune_config']
init_param = modelscope_dict['init_model']
cmvn_file = modelscope_dict['cmvn_file']
seg_dict_file = modelscope_dict['seg_dict']
# overwrite parameters
with open(config) as f:
configs = yaml.safe_load(f)
with open(finetune_config) as f:
finetune_configs = yaml.safe_load(f)
# set data_types
if dataset_type == "large":
finetune_configs["dataset_conf"]["data_types"] = "sound,text"
finetune_configs = update_dct(configs, finetune_configs)
for key, value in finetune_configs.items():
if hasattr(args, key):
setattr(args, key, value)
# prepare data
args.dataset_type = dataset_type
if args.dataset_type == "small":
args.train_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, train_set), "speech", "sound"],
["{}/{}/text".format(data_dir, train_set), "text", "text"]]
args.valid_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, dev_set), "speech", "sound"],
["{}/{}/text".format(data_dir, dev_set), "text", "text"]]
elif args.dataset_type == "large":
args.train_data_file = None
args.valid_data_file = None
else:
raise ValueError(f"Not supported dataset_type={args.dataset_type}")
args.init_param = [init_param]
args.cmvn_file = cmvn_file
if os.path.exists(seg_dict_file):
args.seg_dict_file = seg_dict_file
else:
args.seg_dict_file = None
args.data_dir = data_dir
args.train_set = train_set
args.dev_set = dev_set
args.output_dir = output_dir
args.gpu_id = args.local_rank
args.config = finetune_config
if optim is not None:
args.optim = optim
if lr is not None:
args.optim_conf["lr"] = lr
if scheduler is not None:
args.scheduler = scheduler
if scheduler_conf is not None:
args.scheduler_conf = scheduler_conf
if specaug is not None:
args.specaug = specaug
if specaug_conf is not None:
args.specaug_conf = specaug_conf
if max_epoch is not None:
args.max_epoch = max_epoch
if batch_bins is not None:
if args.dataset_type == "small":
args.batch_bins = batch_bins
elif args.dataset_type == "large":
args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
else:
raise ValueError(f"Not supported dataset_type={args.dataset_type}")
if args.normalize in ["null", "none", "None"]:
args.normalize = None
if args.patience in ["null", "none", "None"]:
args.patience = None
args.local_rank = local_rank
args.distributed = distributed
ASRTask.finetune_args = args
return ASRTask
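# Usage sketch (hypothetical paths/keys): the modelscope_dict is expected to
# carry at least the entries read above (mode, am_model_config, finetune_config,
# init_model, cmvn_file, seg_dict). The returned task class holds the populated
# arguments in its finetune_args attribute, so one plausible way to start
# fine-tuning is:
#
#   modelscope_dict = {
#       "mode": "paraformer",
#       "am_model_config": "model_dir/config.yaml",
#       "finetune_config": "model_dir/finetune_config.yaml",
#       "init_model": "model_dir/model.pb",
#       "cmvn_file": "model_dir/am.mvn",
#       "seg_dict": "model_dir/seg_dict",
#   }
#   task = build_trainer(modelscope_dict, data_dir="data", output_dir="exp/finetune")
#   task.main(args=task.finetune_args)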

View File

@@ -0,0 +1,45 @@
#!/usr/bin/env python3
import os
from funasr_local.tasks.data2vec import Data2VecTask
def parse_args():
parser = Data2VecTask.get_parser()
parser.add_argument(
"--gpu_id",
type=int,
default=0,
help="local gpu id.",
)
args = parser.parse_args()
return args
def main(args=None, cmd=None):
# for data2vec Training
Data2VecTask.main(args=args, cmd=cmd)
if __name__ == '__main__':
args = parse_args()
# setup local gpu_id
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
args.distributed = True
else:
args.distributed = False
assert args.num_worker_count == 1
# re-compute batch size: when dataset type is small
if args.dataset_type == "small":
if args.batch_size is not None:
args.batch_size = args.batch_size * args.ngpu
if args.batch_bins is not None:
args.batch_bins = args.batch_bins * args.ngpu
main(args=args)

View File

@@ -0,0 +1,185 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speaker Verification",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=False)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--njob",
type=int,
default=1,
help="The number of jobs for each gpu",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=True)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",
type=str,
help="VAD infer configuration",
)
group.add_argument(
"--vad_model_file",
type=str,
help="VAD model parameter file",
)
group.add_argument(
"--diar_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--diar_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global CMVN file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("The inference configuration related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument(
"--diar_smooth_size",
type=int,
default=121,
help="The smoothing size for post-processing"
)
return parser
def inference_launch(mode, **kwargs):
if mode == "sond":
from funasr_local.bin.sond_inference import inference_modelscope
return inference_modelscope(mode=mode, **kwargs)
elif mode == "sond_demo":
from funasr_local.bin.sond_inference import inference_modelscope
param_dict = {
"extract_profile": True,
"sv_train_config": "sv.yaml",
"sv_model_file": "sv.pb",
}
if "param_dict" in kwargs and kwargs["param_dict"] is not None:
for key in param_dict:
if key not in kwargs["param_dict"]:
kwargs["param_dict"][key] = param_dict[key]
else:
kwargs["param_dict"] = param_dict
return inference_modelscope(mode=mode, **kwargs)
elif mode == "eend-ola":
from funasr_local.bin.eend_ola_inference import inference_modelscope
return inference_modelscope(mode=mode, **kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
parser.add_argument(
"--mode",
type=str,
default="sond",
help="The decoding mode",
)
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
# set logging messages
logging.basicConfig(
level=args.log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("Decoding args: {}".format(kwargs))
# gpu setting
if args.ngpu > 0:
jobid = int(args.output_dir.split(".")[-1])
gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
inference_launch(**kwargs)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,46 @@
#!/usr/bin/env python3
import os
from funasr_local.tasks.diar import DiarTask
# for diarization training
def parse_args():
parser = DiarTask.get_parser()
parser.add_argument(
"--gpu_id",
type=int,
default=0,
help="local gpu id.",
)
args = parser.parse_args()
return args
def main(args=None, cmd=None):
# for diarization training
DiarTask.main(args=args, cmd=cmd)
if __name__ == '__main__':
args = parse_args()
# setup local gpu_id
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
args.distributed = True
else:
args.distributed = False
assert args.num_worker_count == 1
# re-compute batch size: when dataset type is small
if args.dataset_type == "small":
if args.batch_size is not None:
args.batch_size = args.batch_size * args.ngpu
if args.batch_bins is not None:
args.batch_bins = args.batch_bins * args.ngpu
main(args=args)

View File

@@ -0,0 +1,429 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import numpy as np
import torch
from scipy.signal import medfilt
from typeguard import check_argument_types
from funasr_local.models.frontend.wav_frontend import WavFrontendMel23
from funasr_local.tasks.diar import EENDOLADiarTask
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
class Speech2Diarization:
"""Speech2Diarlization class
Examples:
>>> import soundfile
>>> import numpy as np
>>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pb")
>>> profile = np.load("profiles.npy")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2diar(audio, profile)
{"spk1": [(int, int), ...], ...}
"""
def __init__(
self,
diar_train_config: Union[Path, str] = None,
diar_model_file: Union[Path, str] = None,
device: str = "cpu",
dtype: str = "float32",
):
assert check_argument_types()
# 1. Build Diarization model
diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file(
config_file=diar_train_config,
model_file=diar_model_file,
device=device
)
frontend = None
if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
frontend = WavFrontendMel23(**diar_train_args.frontend_conf)
# set up seed for eda
np.random.seed(diar_train_args.seed)
torch.manual_seed(diar_train_args.seed)
torch.cuda.manual_seed(diar_train_args.seed)
os.environ['PYTORCH_SEED'] = str(diar_train_args.seed)
logging.info("diar_model: {}".format(diar_model))
logging.info("diar_train_args: {}".format(diar_train_args))
diar_model.to(dtype=getattr(torch, dtype)).eval()
self.diar_model = diar_model
self.diar_train_args = diar_train_args
self.device = device
self.dtype = dtype
self.frontend = frontend
@torch.no_grad()
def __call__(
self,
speech: Union[torch.Tensor, np.ndarray],
speech_lengths: Union[torch.Tensor, np.ndarray] = None
):
"""Inference
Args:
speech: Input speech data
Returns:
diarization results
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
self.diar_model.frontend = None
else:
feats = speech
feats_len = speech_lengths
batch = {"speech": feats, "speech_lengths": feats_len}
batch = to_device(batch, device=self.device)
results = self.diar_model.estimate_sequential(**batch)
return results
@staticmethod
def from_pretrained(
model_tag: Optional[str] = None,
**kwargs: Optional[Any],
):
"""Build Speech2Diarization instance from the pretrained model.
Args:
model_tag (Optional[str]): Model tag of the pretrained models.
Currently, the tags of espnet_model_zoo are supported.
Returns:
Speech2Diarization: Speech2Diarization instance.
"""
if model_tag is not None:
try:
from espnet_model_zoo.downloader import ModelDownloader
except ImportError:
logging.error(
"`espnet_model_zoo` is not installed. "
"Please install via `pip install -U espnet_model_zoo`."
)
raise
d = ModelDownloader()
kwargs.update(**d.download_and_unpack(model_tag))
return Speech2Diarization(**kwargs)
def inference_modelscope(
diar_train_config: str,
diar_model_file: str,
output_dir: Optional[str] = None,
batch_size: int = 1,
dtype: str = "float32",
ngpu: int = 1,
num_workers: int = 0,
log_level: Union[int, str] = "INFO",
key_file: Optional[str] = None,
model_tag: Optional[str] = None,
allow_variable_data_keys: bool = True,
streaming: bool = False,
param_dict: Optional[dict] = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("param_dict: {}".format(param_dict))
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Build speech2diar
speech2diar_kwargs = dict(
diar_train_config=diar_train_config,
diar_model_file=diar_model_file,
device=device,
dtype=dtype,
)
logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
speech2diar = Speech2Diarization.from_pretrained(
model_tag=model_tag,
**speech2diar_kwargs,
)
speech2diar.diar_model.eval()
def output_results_str(results: dict, uttid: str):
rst = []
mid = uttid.rsplit("-", 1)[0]
for key in results:
results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]]
template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
for spk, segs in results.items():
rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])
return "\n".join(rst)
def _forward(
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
output_dir_v2: Optional[str] = None,
param_dict: Optional[dict] = None,
):
# 2. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
loader = EENDOLADiarTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False),
collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
# 3. Start for-loop
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
os.makedirs(output_path, exist_ok=True)
output_writer = open("{}/result.txt".format(output_path), "w")
result_list = []
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
results = speech2diar(**batch)
# post process
a = results[0][0].cpu().numpy()
a = medfilt(a, (11, 1))
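# The (frames x speakers) activity matrix is median-filtered over an 11-frame
# window to suppress single-frame flips; each remaining active run per speaker
# is then emitted below as an RTTM "SPEAKER" line, converting frame indices to
# seconds via the division by 10 (i.e. assuming a 100 ms frame shift).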
rst = []
for spkid, frames in enumerate(a.T):
frames = np.pad(frames, (1, 1), 'constant')
changes, = np.where(np.diff(frames, axis=0) != 0)
fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
for s, e in zip(changes[::2], changes[1::2]):
st = s / 10.
dur = (e - s) / 10.
rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
# Only supporting batch_size==1
value = "\n".join(rst)
item = {"key": keys[0], "value": value}
result_list.append(item)
if output_path is not None:
output_writer.write(value)
output_writer.flush()
if output_path is not None:
output_writer.close()
return result_list
return _forward
def inference(
data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
diar_train_config: Optional[str],
diar_model_file: Optional[str],
output_dir: Optional[str] = None,
batch_size: int = 1,
dtype: str = "float32",
ngpu: int = 0,
seed: int = 0,
num_workers: int = 1,
log_level: Union[int, str] = "INFO",
key_file: Optional[str] = None,
model_tag: Optional[str] = None,
allow_variable_data_keys: bool = True,
streaming: bool = False,
smooth_size: int = 83,
dur_threshold: int = 10,
out_format: str = "vad",
**kwargs,
):
inference_pipeline = inference_modelscope(
diar_train_config=diar_train_config,
diar_model_file=diar_model_file,
output_dir=output_dir,
batch_size=batch_size,
dtype=dtype,
ngpu=ngpu,
seed=seed,
num_workers=num_workers,
log_level=log_level,
key_file=key_file,
model_tag=model_tag,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
smooth_size=smooth_size,
dur_threshold=dur_threshold,
out_format=out_format,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs=None)
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speaker verification/x-vector extraction",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=False)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--diar_train_config",
type=str,
help="diarization training configuration",
)
group.add_argument(
"--diar_model_file",
type=str,
help="diarization model parameter file",
)
group.add_argument(
"--dur_threshold",
type=int,
default=10,
help="The threshold for short segments in number frames"
)
parser.add_argument(
"--smooth_size",
type=int,
default=83,
help="The smoothing window length in number frames"
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
parser.add_argument("--streaming", type=str2bool, default=False)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
logging.info("args: {}".format(kwargs))
if args.output_dir is None:
jobid, n_gpu = 1, 1
gpuid = args.gpuid_list.split(",")[jobid - 1]
else:
jobid = int(args.output_dir.split(".")[-1])
n_gpu = len(args.gpuid_list.split(","))
gpuid = args.gpuid_list.split(",")[(jobid - 1) % n_gpu]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
results_list = inference(**kwargs)
for results in results_list:
print("{} {}".format(results["key"], results["value"]))
if __name__ == "__main__":
main()
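# Programmatic usage sketch (placeholder paths): inference_modelscope() returns
# a closure that can be called repeatedly on new inputs.
#
#   forward = inference_modelscope(
#       diar_train_config="exp/eend_ola/config.yaml",
#       diar_model_file="exp/eend_ola/model.pb",
#       output_dir="exp/eend_ola/decode",
#   )
#   results = forward([("data/test/wav.scp", "speech", "sound")])
#   # each item is {"key": utt_id, "value": RTTM-style "SPEAKER ..." lines}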

View File

@@ -0,0 +1,211 @@
#!/usr/bin/env python3
import argparse
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import numpy as np
import torch
from torch.nn.parallel import data_parallel
from typeguard import check_argument_types
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.fileio.datadir_writer import DatadirWriter
from funasr_local.tasks.lm import LMTask
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.forward_adaptor import ForwardAdaptor
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.types import float_or_none
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
def calc_perplexity(
output_dir: str,
batch_size: int,
dtype: str,
ngpu: int,
seed: int,
num_workers: int,
log_level: Union[int, str],
data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
key_file: Optional[str],
train_config: Optional[str],
model_file: Optional[str],
log_base: Optional[float],
allow_variable_data_keys: bool,
):
assert check_argument_types()
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1:
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build LM
model, train_args = LMTask.build_model_from_file(config_file=train_config, model_file=model_file, device=device)
# Wrap the model to make model.nll() data-parallel
wrapped_model = ForwardAdaptor(model, "nll")
wrapped_model.to(dtype=getattr(torch, dtype)).eval()
logging.info(f"Model:\n{model}")
# 3. Build data-iterator
loader = LMTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=LMTask.build_preprocess_fn(train_args, False),
collate_fn=LMTask.build_collate_fn(train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
# 4. Start for-loop
with DatadirWriter(output_dir) as writer:
total_nll = 0.0
total_ntokens = 0
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
with torch.no_grad():
batch = to_device(batch, device)
if ngpu <= 1:
# NOTE(kamo): data_parallel also should work with ngpu=1,
# but for debuggability it's better to keep this block.
nll, lengths = wrapped_model(**batch)
else:
nll, lengths = data_parallel(
wrapped_model, (), range(ngpu), module_kwargs=batch
)
assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
# nll: (B, L) -> (B,)
nll = nll.detach().cpu().numpy().sum(1)
# lengths: (B,)
lengths = lengths.detach().cpu().numpy()
total_nll += nll.sum()
total_ntokens += lengths.sum()
for key, _nll, ntoken in zip(keys, nll, lengths):
if log_base is None:
utt_ppl = np.exp(_nll / ntoken)
else:
utt_ppl = log_base ** (_nll / ntoken / np.log(log_base))
# Write PPL of each utts for debugging or analysis
writer["utt2nll"][key] = str(-_nll)
writer["utt2ppl"][key] = str(utt_ppl)
writer["utt2ntokens"][key] = str(ntoken)
if log_base is None:
ppl = np.exp(total_nll / total_ntokens)
else:
ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
with (Path(output_dir) / "ppl").open("w", encoding="utf-8") as f:
f.write(f"{ppl}\n")
with (Path(output_dir) / "base").open("w", encoding="utf-8") as f:
if log_base is None:
_log_base = np.e
else:
_log_base = log_base
f.write(f"{_log_base}\n")
logging.info(f"PPL={ppl}")
def get_parser():
parser = config_argparse.ArgumentParser(
description="Calc perplexity",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
parser.add_argument(
"--log_base",
type=float_or_none,
default=None,
help="The base of logarithm for Perplexity. "
"If None, napier's constant is used.",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=True,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument("--train_config", type=str)
group.add_argument("--model_file", type=str)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
calc_perplexity(**kwargs)
if __name__ == "__main__":
main()
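# Usage sketch (placeholder paths; the module path is assumed, and the input is
# given as a comma-separated path,name,type triple as parsed by str2triple_str):
# per-utterance nll/ppl/ntokens plus the aggregate "ppl" and "base" files are
# written under --output_dir.
#
#   python -m funasr_local.bin.lm_calc_perplexity \
#       --train_config exp/lm/config.yaml --model_file exp/lm/model.pb \
#       --data_path_and_name_and_type data/test/text,text,text \
#       --output_dir exp/lm/ppl_test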

View File

@@ -0,0 +1,406 @@
#!/usr/bin/env python3
import argparse
import logging
from pathlib import Path
import sys
import os
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
import numpy as np
import torch
from torch.nn.parallel import data_parallel
from typeguard import check_argument_types
from funasr_local.tasks.lm import LMTask
from funasr_local.datasets.preprocessor import LMPreprocessor
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.fileio.datadir_writer import DatadirWriter
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.forward_adaptor import ForwardAdaptor
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.types import float_or_none
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
def inference(
output_dir: str,
batch_size: int,
dtype: str,
ngpu: int,
seed: int,
num_workers: int,
log_level: Union[int, str],
train_config: Optional[str],
model_file: Optional[str],
log_base: Optional[float],
key_file: Optional[str] = None,
allow_variable_data_keys: bool = False,
split_with_space: Optional[bool] = False,
seg_dict_file: Optional[str] = None,
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
raw_inputs: Union[List[Any], bytes, str] = None,
**kwargs,
):
inference_pipeline = inference_modelscope(
output_dir=output_dir,
raw_inputs=raw_inputs,
batch_size=batch_size,
dtype=dtype,
ngpu=ngpu,
seed=seed,
num_workers=num_workers,
log_level=log_level,
key_file=key_file,
train_config=train_config,
model_file=model_file,
log_base=log_base,
allow_variable_data_keys=allow_variable_data_keys,
split_with_space=split_with_space,
seg_dict_file=seg_dict_file,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
batch_size: int,
dtype: str,
ngpu: int,
seed: int,
num_workers: int,
log_level: Union[int, str],
key_file: Optional[str],
train_config: Optional[str],
model_file: Optional[str],
log_base: Optional[float] = 10,
allow_variable_data_keys: bool = False,
split_with_space: Optional[bool] = False,
seg_dict_file: Optional[str] = None,
output_dir: Optional[str] = None,
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build Model
model, train_args = LMTask.build_model_from_file(
train_config, model_file, device)
wrapped_model = ForwardAdaptor(model, "nll")
wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
logging.info(f"Model:\n{model}")
preprocessor = LMPreprocessor(
train=False,
token_type=train_args.token_type,
token_list=train_args.token_list,
bpemodel=train_args.bpemodel,
text_cleaner=train_args.cleaner,
g2p_type=train_args.g2p,
text_name="text",
non_linguistic_symbols=train_args.non_linguistic_symbols,
split_with_space=split_with_space,
seg_dict_file=seg_dict_file
)
def _forward(
data_path_and_name_and_type,
raw_inputs: Union[List[Any], bytes, str] = None,
output_dir_v2: Optional[str] = None,
param_dict: dict = None,
):
results = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
writer = DatadirWriter(output_path)
else:
writer = None
if raw_inputs is not None:
line = raw_inputs.strip()
key = "lm demo"
if line == "":
item = {'key': key, 'value': ""}
results.append(item)
return results
batch = {}
batch['text'] = line
if preprocessor is not None:
batch = preprocessor(key, batch)
# Force data-precision
for name in batch:
value = batch[name]
if not isinstance(value, np.ndarray):
raise RuntimeError(
f"All values must be converted to np.ndarray object "
f'by preprocessing, but "{name}" is still {type(value)}.'
)
# Cast to desired type
if value.dtype.kind == "f":
value = value.astype("float32")
elif value.dtype.kind == "i":
value = value.astype("long")
else:
raise NotImplementedError(f"Not supported dtype: {value.dtype}")
batch[name] = value
batch["text_lengths"] = torch.from_numpy(
np.array([len(batch["text"])], dtype='int32'))
batch["text"] = np.expand_dims(batch["text"], axis=0)
with torch.no_grad():
batch = to_device(batch, device)
if ngpu <= 1:
nll, lengths = wrapped_model(**batch)
else:
nll, lengths = data_parallel(
wrapped_model, (), range(ngpu), module_kwargs=batch
)
## compute ppl
ppl_out_batch = ""
ids2tokens = preprocessor.token_id_converter.ids2tokens
for sent_ids, sent_nll in zip(batch['text'], nll):
pre_word = "<s>"
cur_word = None
sent_lst = ids2tokens(sent_ids) + ['</s>']
ppl_out = " ".join(sent_lst) + "\n"
for word, word_nll in zip(sent_lst, sent_nll):
cur_word = word
word_nll = -word_nll.cpu()
if log_base is None:
word_prob = np.exp(word_nll)
else:
word_prob = log_base ** (word_nll / np.log(log_base))
ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
cur=cur_word,
pre=pre_word,
prob=round(word_prob.item(), 8),
word_nll=round(word_nll.item(), 8)
)
pre_word = cur_word
sent_nll_mean = sent_nll.mean().cpu().numpy()
sent_nll_sum = sent_nll.sum().cpu().numpy()
if log_base is None:
sent_ppl = np.exp(sent_nll_mean)
else:
sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
sent_nll=round(-sent_nll_sum.item(), 4),
sent_ppl=round(sent_ppl.item(), 4)
)
ppl_out_batch += ppl_out
item = {'key': key, 'value': ppl_out}
if writer is not None:
writer["ppl"][key+":\n"] = ppl_out
results.append(item)
return results
# 3. Build data-iterator
loader = LMTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=preprocessor,
collate_fn=LMTask.build_collate_fn(train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
# 4. Start for-loop
total_nll = 0.0
total_ntokens = 0
ppl_out_all = ""
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
ppl_out_batch = ""
with torch.no_grad():
batch = to_device(batch, device)
if ngpu <= 1:
# NOTE(kamo): data_parallel also should work with ngpu=1,
# but for debuggability it's better to keep this block.
nll, lengths = wrapped_model(**batch)
else:
nll, lengths = data_parallel(
wrapped_model, (), range(ngpu), module_kwargs=batch
)
## print ppl
ids2tokens = preprocessor.token_id_converter.ids2tokens
for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
pre_word = "<s>"
cur_word = None
sent_lst = ids2tokens(sent_ids) + ['</s>']
ppl_out = " ".join(sent_lst) + "\n"
for word, word_nll in zip(sent_lst, sent_nll):
cur_word = word
word_nll = -word_nll.cpu()
if log_base is None:
word_prob = np.exp(word_nll)
else:
word_prob = log_base ** (word_nll / np.log(log_base))
ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
cur=cur_word,
pre=pre_word,
prob=round(word_prob.item(), 8),
word_nll=round(word_nll.item(), 8)
)
pre_word = cur_word
sent_nll_mean = sent_nll.mean().cpu().numpy()
sent_nll_sum = sent_nll.sum().cpu().numpy()
if log_base is None:
sent_ppl = np.exp(sent_nll_mean)
else:
sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
sent_nll=round(-sent_nll_sum.item(), 4),
sent_ppl=round(sent_ppl.item(), 4)
)
ppl_out_batch += ppl_out
utt2nll = round(-sent_nll_sum.item(), 5)
item = {'key': key, 'value': ppl_out}
if writer is not None:
writer["ppl"][key+":\n"] = ppl_out
writer["utt2nll"][key] = str(utt2nll)
results.append(item)
ppl_out_all += ppl_out_batch
assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
# nll: (B, L) -> (B,)
nll = nll.detach().cpu().numpy().sum(1)
# lengths: (B,)
lengths = lengths.detach().cpu().numpy()
total_nll += nll.sum()
total_ntokens += lengths.sum()
if log_base is None:
ppl = np.exp(total_nll / total_ntokens)
else:
ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
total_nll=round(-total_nll.item(), 4),
total_ppl=round(ppl.item(), 4)
)
item = {'key': 'AVG PPL', 'value': avg_ppl}
ppl_out_all += avg_ppl
if writer is not None:
writer["ppl"]["AVG PPL : "] = avg_ppl
results.append(item)
return results
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="Calc perplexity",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=False)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
parser.add_argument(
"--log_base",
type=float_or_none,
default=10,
help="The base of logarithm for Perplexity. "
"If None, napier's constant is used.",
required=False
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
action="append",
required=False
)
group.add_argument(
"--raw_inputs",
type=str,
required=False
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group.add_argument("--split_with_space", type=str2bool, default=False)
group.add_argument("--seg_dict_file", type=str_or_none)
group = parser.add_argument_group("The model configuration related")
group.add_argument("--train_config", type=str)
group.add_argument("--model_file", type=str)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
inference(**kwargs)
if __name__ == "__main__":
main()
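# Programmatic usage sketch (placeholder paths): score a raw text string instead
# of a data file; the returned items carry the SRILM-style per-word report built
# above.
#
#   forward = inference_modelscope(
#       batch_size=1, dtype="float32", ngpu=0, seed=0, num_workers=1,
#       log_level="INFO", key_file=None,
#       train_config="exp/lm/config.yaml", model_file="exp/lm/model.pb",
#   )
#   results = forward(None, raw_inputs="hello world")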

View File

@@ -0,0 +1,130 @@
#!/usr/bin/env python3
import argparse
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from funasr_local.utils.types import float_or_none
def get_parser():
parser = config_argparse.ArgumentParser(
description="Calc perplexity",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--gpuid_list", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument("--njob", type=int, default=1, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
parser.add_argument(
"--log_base",
type=float_or_none,
default=10,
help="The base of logarithm for Perplexity. "
"If None, napier's constant is used.",
required=False
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
action="append",
required=False
)
group.add_argument(
"--raw_inputs",
type=str,
required=False
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group.add_argument("--split_with_space", type=str2bool, default=False)
group.add_argument("--seg_dict_file", type=str_or_none)
group = parser.add_argument_group("The model configuration related")
group.add_argument("--train_config", type=str)
group.add_argument("--model_file", type=str)
group.add_argument("--mode", type=str, default="lm")
return parser
def inference_launch(mode, **kwargs):
if mode == "transformer":
from funasr_local.bin.lm_inference import inference_modelscope
return inference_modelscope(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
# set logging messages
logging.basicConfig(
level=args.log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("Decoding args: {}".format(kwargs))
# gpu setting
if args.ngpu > 0:
jobid = int(args.output_dir.split(".")[-1])
gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
kwargs.pop("gpuid_list", None)
kwargs.pop("njob", None)
results = inference_launch(**kwargs)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,46 @@
#!/usr/bin/env python3
import os
from funasr_local.tasks.lm import LMTask
# for LM Training
def parse_args():
parser = LMTask.get_parser()
parser.add_argument(
"--gpu_id",
type=int,
default=0,
help="local gpu id.",
)
args = parser.parse_args()
return args
def main(args=None, cmd=None):
# for LM Training
LMTask.main(args=args, cmd=cmd)
if __name__ == '__main__':
args = parse_args()
# setup local gpu_id
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
args.distributed = True
else:
args.distributed = False
assert args.num_worker_count == 1
# re-compute batch size: when dataset type is small
if args.dataset_type == "small" and args.ngpu != 0:
if args.batch_size is not None:
args.batch_size = args.batch_size * args.ngpu
if args.batch_bins is not None:
args.batch_bins = args.batch_bins * args.ngpu
main(args=args)

View File

@@ -0,0 +1,90 @@
#!/usr/bin/env python3
import argparse
import logging
import os
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description="decoding configs",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--model_name",
type=str,
default="speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
help="model name in modelscope")
parser.add_argument("--model_revision",
type=str,
default="v1.0.4",
help="model revision in modelscope")
parser.add_argument("--local_model_path",
type=str,
default=None,
help="local model path, usually for fine-tuning")
parser.add_argument("--wav_list",
type=str,
help="input wav list")
parser.add_argument("--output_file",
type=str,
help="saving decoding results")
parser.add_argument(
"--njob",
type=int,
default=1,
help="The number of jobs for each gpu",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
args = parser.parse_args()
# set logging messages
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("Decoding args: {}".format(args))
# gpu setting
if args.ngpu > 0:
jobid = int(args.output_file.split(".")[-1])
gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
if args.local_model_path is None:
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/{}".format(args.model_name),
model_revision=args.model_revision)
else:
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model=args.local_model_path)
with open(args.wav_list, 'r') as f_wav:
wav_lines = f_wav.readlines()
with open(args.output_file, "w") as f_out:
for line in wav_lines:
wav_id, wav_path = line.strip().split()
logging.info("decoding, utt_id: ['{}']".format(wav_id))
rec_result = inference_pipeline(audio_in=wav_path)
if 'text' in rec_result:
text = rec_result["text"]
else:
text = ''
f_out.write(wav_id + " " + text + "\n")
logging.info("best hypo: {} \n".format(text))

View File

@@ -0,0 +1,112 @@
#!/usr/bin/env python3
import argparse
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from funasr_local.utils.types import float_or_none
def get_parser():
parser = config_argparse.ArgumentParser(
description="Punctuation inference",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--gpuid_list", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument("--njob", type=int, default=1, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group = parser.add_argument_group("Input data related")
group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
group.add_argument("--raw_inputs", type=str, required=False)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--cache", type=list, required=False)
group.add_argument("--param_dict", type=dict, required=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument("--train_config", type=str)
group.add_argument("--model_file", type=str)
group.add_argument("--mode", type=str, default="punc")
return parser
def inference_launch(mode, **kwargs):
if mode == "punc":
from funasr_local.bin.punctuation_infer import inference_modelscope
return inference_modelscope(**kwargs)
if mode == "punc_VadRealtime":
from funasr_local.bin.punctuation_infer_vadrealtime import inference_modelscope
return inference_modelscope(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
# set logging messages
logging.basicConfig(
level=args.log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("Decoding args: {}".format(kwargs))
# gpu setting
if args.ngpu > 0:
jobid = int(args.output_dir.split(".")[-1])
gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
kwargs.pop("gpuid_list", None)
kwargs.pop("njob", None)
results = inference_launch(**kwargs)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,43 @@
#!/usr/bin/env python3
import os
from funasr_local.tasks.punctuation import PunctuationTask
def parse_args():
parser = PunctuationTask.get_parser()
parser.add_argument(
"--gpu_id",
type=int,
default=0,
help="local gpu id.",
)
parser.add_argument(
"--punc_list",
type=str,
default=None,
help="Punctuation list",
)
args = parser.parse_args()
return args
def main(args=None, cmd=None):
"""
punc training.
"""
PunctuationTask.main(args=args, cmd=cmd)
if __name__ == "__main__":
args = parse_args()
# setup local gpu_id
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
args.distributed = True
else:
args.distributed = False
main(args=args)

View File

@@ -0,0 +1,320 @@
#!/usr/bin/env python3
import argparse
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Any
from typing import List
import numpy as np
import torch
from typeguard import check_argument_types
from funasr_local.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.tasks.punctuation import PunctuationTask
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.forward_adaptor import ForwardAdaptor
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from funasr_local.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
def __init__(
self,
train_config: Optional[str],
model_file: Optional[str],
device: str = "cpu",
dtype: str = "float32",
):
# Build Model
model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
self.device = device
# Wrap the model so that model.inference() is exposed via the ForwardAdaptor
self.wrapped_model = ForwardAdaptor(model, "inference")
self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
# logging.info(f"Model:\n{model}")
self.punc_list = train_args.punc_list
self.period = 0
for i in range(len(self.punc_list)):
if self.punc_list[i] == ",":
self.punc_list[i] = ""
elif self.punc_list[i] == "?":
self.punc_list[i] = ""
elif self.punc_list[i] == "":
self.period = i
self.preprocessor = CodeMixTokenizerCommonPreprocessor(
train=False,
token_type=train_args.token_type,
token_list=train_args.token_list,
bpemodel=train_args.bpemodel,
text_cleaner=train_args.cleaner,
g2p_type=train_args.g2p,
text_name="text",
non_linguistic_symbols=train_args.non_linguistic_symbols,
)
@torch.no_grad()
def __call__(self, text: Union[list, str], split_size=20):
data = {"text": text}
result = self.preprocessor(data=data, uid="12938712838719")
split_text = self.preprocessor.pop_split_text_data(result)
mini_sentences = split_to_mini_sentence(split_text, split_size)
mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
assert len(mini_sentences) == len(mini_sentences_id)
cache_sent = []
cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
new_mini_sentence = ""
new_mini_sentence_punc = []
cache_pop_trigger_limit = 200
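# Chunking/cache strategy used below: the token sequence is processed in
# mini-sentences of `split_size` tokens; tokens after the last predicted
# sentence end (。/？) in a chunk are kept in cache_sent / cache_sent_id and
# prepended to the next chunk, so sentences are never punctuated across a chunk
# boundary. If no sentence end is found and the carried text exceeds
# cache_pop_trigger_limit tokens, the last comma is promoted to a period to
# flush the cache.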
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
data = {
"text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
"text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
}
data = to_device(data, self.device)
y, _ = self.wrapped_model(**data)
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
punctuations = indices
if indices.size()[0] != 1:
punctuations = torch.squeeze(indices)
assert punctuations.size()[0] == len(mini_sentence)
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences) - 1:
sentenceEnd = -1
last_comma_index = -1
for i in range(len(punctuations) - 2, 1, -1):
if self.punc_list[punctuations[i]] == "" or self.punc_list[punctuations[i]] == "":
sentenceEnd = i
break
if last_comma_index < 0 and self.punc_list[punctuations[i]] == "，":
last_comma_index = i
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
# The sentence is too long; cut it off at a comma.
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = self.period
cache_sent = mini_sentence[sentenceEnd + 1:]
cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
mini_sentence = mini_sentence[0:sentenceEnd + 1]
punctuations = punctuations[0:sentenceEnd + 1]
# if len(punctuations) == 0:
# continue
punctuations_np = punctuations.cpu().numpy()
new_mini_sentence_punc += [int(x) for x in punctuations_np]
words_with_punc = []
for i in range(len(mini_sentence)):
if i > 0:
if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
mini_sentence[i] = " " + mini_sentence[i]
words_with_punc.append(mini_sentence[i])
if self.punc_list[punctuations[i]] != "_":
words_with_punc.append(self.punc_list[punctuations[i]])
new_mini_sentence += "".join(words_with_punc)
# Add Period for the end of the sentence
new_mini_sentence_out = new_mini_sentence
new_mini_sentence_punc_out = new_mini_sentence_punc
if mini_sentence_i == len(mini_sentences) - 1:
if new_mini_sentence[-1] == "" or new_mini_sentence[-1] == "":
new_mini_sentence_out = new_mini_sentence[:-1] + ""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
elif new_mini_sentence[-1] != "" and new_mini_sentence[-1] != "":
new_mini_sentence_out = new_mini_sentence + ""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
return new_mini_sentence_out, new_mini_sentence_punc_out
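# Usage sketch (placeholder paths): restore punctuation for a raw transcript.
#
#   t2p = Text2Punc("exp/punc/config.yaml", "exp/punc/model.pb", device="cpu")
#   text_with_punc, punc_ids = t2p("跨境河流是养育沿岸人民的生命之源")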
def inference(
batch_size: int,
dtype: str,
ngpu: int,
seed: int,
num_workers: int,
output_dir: str,
log_level: Union[int, str],
train_config: Optional[str],
model_file: Optional[str],
key_file: Optional[str] = None,
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
raw_inputs: Union[List[Any], bytes, str] = None,
cache: List[Any] = None,
param_dict: dict = None,
**kwargs,
):
inference_pipeline = inference_modelscope(
output_dir=output_dir,
batch_size=batch_size,
dtype=dtype,
ngpu=ngpu,
seed=seed,
num_workers=num_workers,
log_level=log_level,
key_file=key_file,
train_config=train_config,
model_file=model_file,
param_dict=param_dict,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
batch_size: int,
dtype: str,
ngpu: int,
seed: int,
num_workers: int,
log_level: Union[int, str],
key_file: Optional[str],
train_config: Optional[str],
model_file: Optional[str],
output_dir: Optional[str] = None,
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
text2punc = Text2Punc(train_config, model_file, device)
def _forward(
data_path_and_name_and_type,
raw_inputs: Union[List[Any], bytes, str] = None,
output_dir_v2: Optional[str] = None,
cache: List[Any] = None,
param_dict: dict = None,
):
results = []
split_size = 20
if raw_inputs is not None:
line = raw_inputs.strip()
key = "demo"
if line == "":
item = {'key': key, 'value': ""}
results.append(item)
return results
result, _ = text2punc(line)
item = {'key': key, 'value': result}
results.append(item)
return results
for inference_text, _, _ in data_path_and_name_and_type:
with open(inference_text, "r", encoding="utf-8") as fin:
for line in fin:
line = line.strip()
segs = line.split("\t")
if len(segs) != 2:
continue
key = segs[0]
if len(segs[1]) == 0:
continue
result, _ = text2punc(segs[1])
item = {'key': key, 'value': result}
results.append(item)
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
output_file_name = "infer.out"
Path(output_path).mkdir(parents=True, exist_ok=True)
output_file_path = (Path(output_path) / output_file_name).absolute()
with open(output_file_path, "w", encoding="utf-8") as fout:
for item_i in results:
key_out = item_i["key"]
value_out = item_i["value"]
fout.write(f"{key_out}\t{value_out}\n")
return results
return _forward
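# Hedged usage sketch (not in the original source): inference_modelscope returns
# the _forward closure, so an offline run can look roughly like the following;
# the config/model paths are placeholders.
#
#   pipeline = inference_modelscope(batch_size=1, dtype="float32", ngpu=0, seed=0,
#                                   num_workers=1, log_level="INFO", key_file=None,
#                                   train_config="punc_config.yaml", model_file="punc.pb")
#   results = pipeline(data_path_and_name_and_type=None, raw_inputs="my raw text")
#   # results -> [{"key": "demo", "value": "my raw text with punctuation"}]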
def get_parser():
parser = config_argparse.ArgumentParser(
description="Punctuation inference",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=False)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group = parser.add_argument_group("Input data related")
group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
group.add_argument("--raw_inputs", type=str, required=False)
group.add_argument("--cache", type=list, required=False)
group.add_argument("--param_dict", type=dict, required=False)
group.add_argument("--key_file", type=str_or_none)
group = parser.add_argument_group("The model configuration related")
group.add_argument("--train_config", type=str)
group.add_argument("--model_file", type=str)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
# kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,311 @@
#!/usr/bin/env python3
import argparse
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Any
from typing import List
import numpy as np
import torch
from typeguard import check_argument_types
from funasr_local.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.tasks.punctuation import PunctuationTask
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.forward_adaptor import ForwardAdaptor
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from funasr_local.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
def __init__(
self,
train_config: Optional[str],
model_file: Optional[str],
device: str = "cpu",
dtype: str = "float32",
):
# Build Model
model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
self.device = device
# Wrap model to make model.nll() data-parallel
self.wrapped_model = ForwardAdaptor(model, "inference")
self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
# logging.info(f"Model:\n{model}")
self.punc_list = train_args.punc_list
self.period = 0
for i in range(len(self.punc_list)):
if self.punc_list[i] == ",":
self.punc_list[i] = ""
elif self.punc_list[i] == "?":
self.punc_list[i] = ""
elif self.punc_list[i] == "":
self.period = i
self.preprocessor = CodeMixTokenizerCommonPreprocessor(
train=False,
token_type=train_args.token_type,
token_list=train_args.token_list,
bpemodel=train_args.bpemodel,
text_cleaner=train_args.cleaner,
g2p_type=train_args.g2p,
text_name="text",
non_linguistic_symbols=train_args.non_linguistic_symbols,
)
print("start decoding!!!")
@torch.no_grad()
def __call__(self, text: Union[list, str], cache: list, split_size=20):
if cache is not None and len(cache) > 0:
precache = "".join(cache)
else:
precache = ""
cache = []
data = {"text": precache + text}
result = self.preprocessor(data=data, uid="12938712838719")
split_text = self.preprocessor.pop_split_text_data(result)
mini_sentences = split_to_mini_sentence(split_text, split_size)
mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
assert len(mini_sentences) == len(mini_sentences_id)
cache_sent = []
cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
sentence_punc_list = []
sentence_words_list= []
cache_pop_trigger_limit = 200
skip_num = 0
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
data = {
"text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
"text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
"vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
}
data = to_device(data, self.device)
y, _ = self.wrapped_model(**data)
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
punctuations = indices
if indices.size()[0] != 1:
punctuations = torch.squeeze(indices)
assert punctuations.size()[0] == len(mini_sentence)
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences) - 1:
sentenceEnd = -1
last_comma_index = -1
for i in range(len(punctuations) - 2, 1, -1):
if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?":
sentenceEnd = i
break
if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
last_comma_index = i
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
# The sentence is too long, cut it off at a comma.
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = self.period
cache_sent = mini_sentence[sentenceEnd + 1:]
cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
mini_sentence = mini_sentence[0:sentenceEnd + 1]
punctuations = punctuations[0:sentenceEnd + 1]
punctuations_np = punctuations.cpu().numpy()
sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
sentence_words_list += mini_sentence
assert len(sentence_punc_list) == len(sentence_words_list)
words_with_punc = []
sentence_punc_list_out = []
for i in range(0, len(sentence_words_list)):
if i > 0:
if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
sentence_words_list[i] = " " + sentence_words_list[i]
if skip_num < len(cache):
skip_num += 1
else:
words_with_punc.append(sentence_words_list[i])
if skip_num >= len(cache):
sentence_punc_list_out.append(sentence_punc_list[i])
if sentence_punc_list[i] != "_":
words_with_punc.append(sentence_punc_list[i])
sentence_out = "".join(words_with_punc)
sentenceEnd = -1
for i in range(len(sentence_punc_list) - 2, 1, -1):
if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "?":
sentenceEnd = i
break
cache_out = sentence_words_list[sentenceEnd + 1 :]
if sentence_out[-1] in self.punc_list:
sentence_out = sentence_out[:-1]
sentence_punc_list_out[-1] = "_"
return sentence_out, sentence_punc_list_out, cache_out
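# Hedged usage sketch (comment added for clarity, not part of the original file):
# the streaming variant threads a word cache between calls so punctuation decisions
# near a chunk boundary can be revised once more context arrives.
#
#   cache = []
#   for chunk in ["hello everyone", "welcome to the live demo"]:
#       sentence, punc_list, cache = text2punc(chunk, cache)
#   # `cache` holds the words after the last sentence end and must be passed
#   # back in on the next call.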
def inference(
batch_size: int,
dtype: str,
ngpu: int,
seed: int,
num_workers: int,
output_dir: str,
log_level: Union[int, str],
train_config: Optional[str],
model_file: Optional[str],
key_file: Optional[str] = None,
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
raw_inputs: Union[List[Any], bytes, str] = None,
cache: List[Any] = None,
param_dict: dict = None,
**kwargs,
):
inference_pipeline = inference_modelscope(
output_dir=output_dir,
batch_size=batch_size,
dtype=dtype,
ngpu=ngpu,
seed=seed,
num_workers=num_workers,
log_level=log_level,
key_file=key_file,
train_config=train_config,
model_file=model_file,
param_dict=param_dict,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs, cache)
def inference_modelscope(
batch_size: int,
dtype: str,
ngpu: int,
seed: int,
num_workers: int,
log_level: Union[int, str],
#cache: list,
key_file: Optional[str],
train_config: Optional[str],
model_file: Optional[str],
output_dir: Optional[str] = None,
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
text2punc = Text2Punc(train_config, model_file, device)
def _forward(
data_path_and_name_and_type,
raw_inputs: Union[List[Any], bytes, str] = None,
output_dir_v2: Optional[str] = None,
cache: List[Any] = None,
param_dict: dict = None,
):
results = []
split_size = 10
cache_in = param_dict["cache"]
if raw_inputs is not None:
line = raw_inputs.strip()
key = "demo"
if line == "":
item = {'key': key, 'value': ""}
results.append(item)
return results
result, _, cache = text2punc(line, cache_in)
param_dict["cache"] = cache
item = {'key': key, 'value': result}
results.append(item)
return results
return results
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="Punctuation inference",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=False)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group = parser.add_argument_group("Input data related")
group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
group.add_argument("--raw_inputs", type=str, required=False)
group.add_argument("--cache", type=list, required=False)
group.add_argument("--param_dict", type=dict, required=False)
group.add_argument("--key_file", type=str_or_none)
group = parser.add_argument_group("The model configuration related")
group.add_argument("--train_config", type=str)
group.add_argument("--model_file", type=str)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
# kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,577 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from collections import OrderedDict
import numpy as np
import soundfile
import torch
from torch.nn import functional as F
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.tasks.diar import DiarTask
from funasr_local.tasks.asr import ASRTask
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from scipy.ndimage import median_filter
from funasr_local.utils.misc import statistic_model_parameters
from funasr_local.datasets.iterable_dataset import load_bytes
class Speech2Diarization:
"""Speech2Xvector class
Examples:
>>> import soundfile
>>> import numpy as np
>>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pb")
>>> profile = np.load("profiles.npy")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2diar(audio, profile)
{"spk1": [(int, int), ...], ...}
"""
def __init__(
self,
diar_train_config: Union[Path, str] = None,
diar_model_file: Union[Path, str] = None,
device: Union[str, torch.device] = "cpu",
batch_size: int = 1,
dtype: str = "float32",
streaming: bool = False,
smooth_size: int = 83,
dur_threshold: float = 10,
):
assert check_argument_types()
# TODO: 1. Build Diarization model
diar_model, diar_train_args = DiarTask.build_model_from_file(
config_file=diar_train_config,
model_file=diar_model_file,
device=device
)
logging.info("diar_model: {}".format(diar_model))
logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
logging.info("diar_train_args: {}".format(diar_train_args))
diar_model.to(dtype=getattr(torch, dtype)).eval()
self.diar_model = diar_model
self.diar_train_args = diar_train_args
self.token_list = diar_train_args.token_list
self.smooth_size = smooth_size
self.dur_threshold = dur_threshold
self.device = device
self.dtype = dtype
def smooth_multi_labels(self, multi_label):
multi_label = median_filter(multi_label, (self.smooth_size, 1), mode="constant", cval=0.0).astype(int)
return multi_label
@staticmethod
def calc_spk_turns(label_arr, spk_list):
turn_list = []
length = label_arr.shape[0]
n_spk = label_arr.shape[1]
for k in range(n_spk):
if spk_list[k] == "None":
continue
in_utt = False
start = 0
for i in range(length):
if label_arr[i, k] == 1 and in_utt is False:
start = i
in_utt = True
if label_arr[i, k] == 0 and in_utt is True:
turn_list.append([spk_list[k], start, i - start])
in_utt = False
if in_utt:
turn_list.append([spk_list[k], start, length - start])
return turn_list
@staticmethod
def seq2arr(seq, vec_dim=8):
def int2vec(x, vec_dim=8, dtype=int):
b = ('{:0' + str(vec_dim) + 'b}').format(x)
# little-endian order: lower bit first
return (np.array(list(b)[::-1]) == '1').astype(dtype)
# process oov
seq = np.array([int(x) for x in seq])
new_seq = []
for i, x in enumerate(seq):
if x < 2 ** vec_dim:
new_seq.append(x)
else:
idx_list = np.where(seq < 2 ** vec_dim)[0]
idx = np.abs(idx_list - i).argmin()
new_seq.append(seq[idx_list[idx]])
return np.row_stack([int2vec(x, vec_dim) for x in new_seq])
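# Worked example (illustrative comment, not in the original source): int2vec
# decodes a pse label as a little-endian bit vector over speakers, so with
# vec_dim=4 the label 5 (binary 0101) becomes [1, 0, 1, 0], i.e. speakers 1 and 3
# are active in that frame. Labels >= 2**vec_dim are treated as out-of-vocabulary
# and replaced by the nearest (by position) in-vocabulary label in the sequence.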
def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):
logits_idx = raw_logits.argmax(-1) # B, T, vocab_size -> B, T
# upsampling outputs to match inputs
ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
logits_idx = F.upsample(
logits_idx.unsqueeze(1).float(),
size=(ut, ),
mode="nearest",
).squeeze(1).long()
logits_idx = logits_idx[0].tolist()
pse_labels = [self.token_list[x] for x in logits_idx]
if output_format == "pse_labels":
return pse_labels, None
multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num] # remove padding speakers
multi_labels = self.smooth_multi_labels(multi_labels)
if output_format == "binary_labels":
return multi_labels, None
spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
spk_turns = self.calc_spk_turns(multi_labels, spk_list)
results = OrderedDict()
for spk, st, dur in spk_turns:
if spk not in results:
results[spk] = []
if dur > self.dur_threshold:
results[spk].append((st, st+dur))
# sort segments in start time ascending
for spk in results:
results[spk] = sorted(results[spk], key=lambda x: x[0])
return results, pse_labels
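# Output format note (hedged, added for clarity): with output_format="speaker_turn"
# the result is an OrderedDict keyed by speaker with frame-level (start, end) pairs,
# e.g. {"spk1": [(0, 512), (830, 1200)], "spk2": [(500, 840)]}; segments shorter
# than dur_threshold frames are dropped, and output_results_str further below
# converts frames to seconds by dividing by 100.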
@torch.no_grad()
def __call__(
self,
speech: Union[torch.Tensor, np.ndarray],
profile: Union[torch.Tensor, np.ndarray],
output_format: str = "speaker_turn"
):
"""Inference
Args:
speech: Input speech data
profile: Speaker profiles
Returns:
diarization results for each speaker
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if isinstance(profile, np.ndarray):
profile = torch.tensor(profile)
# data: (Nsamples,) -> (1, Nsamples)
speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
profile = profile.unsqueeze(0).to(getattr(torch, self.dtype))
# lengths: (1,)
speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
profile_lengths = profile.new_full([1], dtype=torch.long, fill_value=profile.size(1))
batch = {"speech": speech, "speech_lengths": speech_lengths,
"profile": profile, "profile_lengths": profile_lengths}
# a. To device
batch = to_device(batch, device=self.device)
logits = self.diar_model.prediction_forward(**batch)
results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)
return results, pse_labels
@staticmethod
def from_pretrained(
model_tag: Optional[str] = None,
**kwargs: Optional[Any],
):
"""Build Speech2Xvector instance from the pretrained model.
Args:
model_tag (Optional[str]): Model tag of the pretrained models.
Currently, the tags of espnet_model_zoo are supported.
Returns:
Speech2Xvector: Speech2Xvector instance.
"""
if model_tag is not None:
try:
from espnet_model_zoo.downloader import ModelDownloader
except ImportError:
logging.error(
"`espnet_model_zoo` is not installed. "
"Please install via `pip install -U espnet_model_zoo`."
)
raise
d = ModelDownloader()
kwargs.update(**d.download_and_unpack(model_tag))
return Speech2Diarization(**kwargs)
def inference_modelscope(
diar_train_config: str,
diar_model_file: str,
output_dir: Optional[str] = None,
batch_size: int = 1,
dtype: str = "float32",
ngpu: int = 0,
seed: int = 0,
num_workers: int = 0,
log_level: Union[int, str] = "INFO",
key_file: Optional[str] = None,
model_tag: Optional[str] = None,
allow_variable_data_keys: bool = True,
streaming: bool = False,
smooth_size: int = 83,
dur_threshold: int = 10,
out_format: str = "vad",
param_dict: Optional[dict] = None,
mode: str = "sond",
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("param_dict: {}".format(param_dict))
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2a. Build speech2xvec [Optional]
if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict."
assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
sv_train_config = param_dict["sv_train_config"]
sv_model_file = param_dict["sv_model_file"]
if "model_dir" in param_dict:
sv_train_config = os.path.join(param_dict["model_dir"], sv_train_config)
sv_model_file = os.path.join(param_dict["model_dir"], sv_model_file)
from funasr_local.bin.sv_inference import Speech2Xvector
speech2xvector_kwargs = dict(
sv_train_config=sv_train_config,
sv_model_file=sv_model_file,
device=device,
dtype=dtype,
streaming=streaming,
embedding_node="resnet1_dense"
)
logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
speech2xvector = Speech2Xvector.from_pretrained(
model_tag=model_tag,
**speech2xvector_kwargs,
)
speech2xvector.sv_model.eval()
# 2b. Build speech2diar
speech2diar_kwargs = dict(
diar_train_config=diar_train_config,
diar_model_file=diar_model_file,
device=device,
dtype=dtype,
streaming=streaming,
smooth_size=smooth_size,
dur_threshold=dur_threshold,
)
logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
speech2diar = Speech2Diarization.from_pretrained(
model_tag=model_tag,
**speech2diar_kwargs,
)
speech2diar.diar_model.eval()
def output_results_str(results: dict, uttid: str):
rst = []
mid = uttid.rsplit("-", 1)[0]
for key in results:
results[key] = [(x[0]/100, x[1]/100) for x in results[key]]
if out_format == "vad":
for spk, segs in results.items():
rst.append("{} {}".format(spk, segs))
else:
template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
for spk, segs in results.items():
rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])
return "\n".join(rst)
def _forward(
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
output_dir_v2: Optional[str] = None,
param_dict: Optional[dict] = None,
):
logging.info("param_dict: {}".format(param_dict))
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, (list, tuple)):
if not isinstance(raw_inputs[0], List):
raw_inputs = [raw_inputs]
assert all([len(example) >= 2 for example in raw_inputs]), \
"The length of test case in raw_inputs must larger than 1 (>=2)."
def prepare_dataset():
for idx, example in enumerate(raw_inputs):
# read waveform file
example = [load_bytes(x) if isinstance(x, bytes) else x
for x in example]
example = [soundfile.read(x)[0] if isinstance(x, str) else x
for x in example]
# convert torch tensor to numpy array
example = [x.numpy() if isinstance(x, torch.Tensor) else x
for x in example]
speech = example[0]
logging.info("Extracting profiles for {} waveforms".format(len(example)-1))
profile = [speech2xvector.calculate_embedding(x) for x in example[1:]]
profile = torch.cat(profile, dim=0)
yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]}
loader = prepare_dataset()
else:
raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
else:
# 3. Build data-iterator
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=None,
collate_fn=None,
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
# 7. Start for-loop
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
os.makedirs(output_path, exist_ok=True)
output_writer = open("{}/result.txt".format(output_path), "w")
pse_label_writer = open("{}/labels.txt".format(output_path), "w")
logging.info("Start to diarize...")
result_list = []
for idx, (keys, batch) in enumerate(loader):
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
results, pse_labels = speech2diar(**batch)
# Only supporting batch_size==1
key, value = keys[0], output_results_str(results, keys[0])
item = {"key": key, "value": value}
result_list.append(item)
if output_path is not None:
output_writer.write(value)
output_writer.flush()
pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
pse_label_writer.flush()
if idx % 100 == 0:
logging.info("Processing {:5d}: {}".format(idx, key))
if output_path is not None:
output_writer.close()
pse_label_writer.close()
return result_list
return _forward
def inference(
data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
diar_train_config: Optional[str],
diar_model_file: Optional[str],
output_dir: Optional[str] = None,
batch_size: int = 1,
dtype: str = "float32",
ngpu: int = 0,
seed: int = 0,
num_workers: int = 1,
log_level: Union[int, str] = "INFO",
key_file: Optional[str] = None,
model_tag: Optional[str] = None,
allow_variable_data_keys: bool = True,
streaming: bool = False,
smooth_size: int = 83,
dur_threshold: int = 10,
out_format: str = "vad",
**kwargs,
):
inference_pipeline = inference_modelscope(
diar_train_config=diar_train_config,
diar_model_file=diar_model_file,
output_dir=output_dir,
batch_size=batch_size,
dtype=dtype,
ngpu=ngpu,
seed=seed,
num_workers=num_workers,
log_level=log_level,
key_file=key_file,
model_tag=model_tag,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
smooth_size=smooth_size,
dur_threshold=dur_threshold,
out_format=out_format,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs=None)
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speaker verification/x-vector extraction",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=False)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--diar_train_config",
type=str,
help="diarization training configuration",
)
group.add_argument(
"--diar_model_file",
type=str,
help="diarization model parameter file",
)
group.add_argument(
"--dur_threshold",
type=int,
default=10,
help="The threshold for short segments in number frames"
)
parser.add_argument(
"--smooth_size",
type=int,
default=83,
help="The smoothing window length in number frames"
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
parser.add_argument("--streaming", type=str2bool, default=False)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
logging.info("args: {}".format(kwargs))
if args.output_dir is None:
jobid, n_gpu = 1, 1
gpuid = args.gpuid_list.split(",")[jobid-1]
else:
jobid = int(args.output_dir.split(".")[-1])
n_gpu = len(args.gpuid_list.split(","))
gpuid = args.gpuid_list.split(",")[(jobid - 1) % n_gpu]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
results_list = inference(**kwargs)
for results in results_list:
print("{} {}".format(results["key"], results["value"]))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,443 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import numpy as np
import torch
from kaldiio import WriteHelper
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.tasks.sv import SVTask
from funasr_local.tasks.asr import ASRTask
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from funasr_local.utils.misc import statistic_model_parameters
class Speech2Xvector:
"""Speech2Xvector class
Examples:
>>> import soundfile
>>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2xvector(audio)
(embedding, ref_embedding, similarity_score)
"""
def __init__(
self,
sv_train_config: Union[Path, str] = None,
sv_model_file: Union[Path, str] = None,
device: str = "cpu",
batch_size: int = 1,
dtype: str = "float32",
streaming: bool = False,
embedding_node: str = "resnet1_dense",
):
assert check_argument_types()
# TODO: 1. Build SV model
sv_model, sv_train_args = SVTask.build_model_from_file(
config_file=sv_train_config,
model_file=sv_model_file,
device=device
)
logging.info("sv_model: {}".format(sv_model))
logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
logging.info("sv_train_args: {}".format(sv_train_args))
sv_model.to(dtype=getattr(torch, dtype)).eval()
self.sv_model = sv_model
self.sv_train_args = sv_train_args
self.device = device
self.dtype = dtype
self.embedding_node = embedding_node
@torch.no_grad()
def calculate_embedding(self, speech: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
# data: (Nsamples,) -> (1, Nsamples)
speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
# lengths: (1,)
lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
batch = {"speech": speech, "speech_lengths": lengths}
# a. To device
batch = to_device(batch, device=self.device)
# b. Forward Encoder
enc, ilens = self.sv_model.encode(**batch)
# c. Forward Pooling
pooling = self.sv_model.pooling_layer(enc)
# d. Forward Decoder
outputs, embeddings = self.sv_model.decoder(pooling)
if self.embedding_node not in embeddings:
raise ValueError("Required embedding node {} not in {}".format(
self.embedding_node, embeddings.keys()))
return embeddings[self.embedding_node]
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray],
ref_speech: Optional[Union[torch.Tensor, np.ndarray]] = None,
) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Union[torch.Tensor, None]]:
"""Inference
Args:
speech: Input speech data
ref_speech: Reference speech to compare
Returns:
embedding, ref_embedding, similarity_score
"""
assert check_argument_types()
self.sv_model.eval()
embedding = self.calculate_embedding(speech)
ref_emb, score = None, None
if ref_speech is not None:
ref_emb = self.calculate_embedding(ref_speech)
score = torch.cosine_similarity(embedding, ref_emb)
results = (embedding, ref_emb, score)
assert check_return_type(results)
return results
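# Hedged usage sketch (comment added for clarity, not part of the original file;
# the wav and model paths are placeholders):
#
#   import soundfile
#   speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb")
#   enroll, _ = soundfile.read("enroll.wav")
#   test, _ = soundfile.read("test.wav")
#   emb, ref_emb, score = speech2xvector(test, ref_speech=enroll)
#   # score is the cosine similarity between the two embeddings; emb/ref_emb are
#   # the raw x-vectors taken from the `embedding_node` layer.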
@staticmethod
def from_pretrained(
model_tag: Optional[str] = None,
**kwargs: Optional[Any],
):
"""Build Speech2Xvector instance from the pretrained model.
Args:
model_tag (Optional[str]): Model tag of the pretrained models.
Currently, the tags of espnet_model_zoo are supported.
Returns:
Speech2Xvector: Speech2Xvector instance.
"""
if model_tag is not None:
try:
from espnet_model_zoo.downloader import ModelDownloader
except ImportError:
logging.error(
"`espnet_model_zoo` is not installed. "
"Please install via `pip install -U espnet_model_zoo`."
)
raise
d = ModelDownloader()
kwargs.update(**d.download_and_unpack(model_tag))
return Speech2Xvector(**kwargs)
def inference_modelscope(
output_dir: Optional[str] = None,
batch_size: int = 1,
dtype: str = "float32",
ngpu: int = 1,
seed: int = 0,
num_workers: int = 0,
log_level: Union[int, str] = "INFO",
key_file: Optional[str] = None,
sv_train_config: Optional[str] = "sv.yaml",
sv_model_file: Optional[str] = "sv.pb",
model_tag: Optional[str] = None,
allow_variable_data_keys: bool = True,
streaming: bool = False,
embedding_node: str = "resnet1_dense",
sv_threshold: float = 0.9465,
param_dict: Optional[dict] = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("param_dict: {}".format(param_dict))
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2xvector
speech2xvector_kwargs = dict(
sv_train_config=sv_train_config,
sv_model_file=sv_model_file,
device=device,
dtype=dtype,
streaming=streaming,
embedding_node=embedding_node
)
logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
speech2xvector = Speech2Xvector.from_pretrained(
model_tag=model_tag,
**speech2xvector_kwargs,
)
speech2xvector.sv_model.eval()
def _forward(
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
param_dict: Optional[dict] = None,
):
logging.info("param_dict: {}".format(param_dict))
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
# 3. Build data-iterator
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=None,
collate_fn=None,
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
# 7. Start for-loop
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
embd_writer, ref_embd_writer, score_writer = None, None, None
if output_path is not None:
os.makedirs(output_path, exist_ok=True)
embd_writer = WriteHelper("ark,scp:{}/xvector.ark,{}/xvector.scp".format(output_path, output_path))
sv_result_list = []
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
embedding, ref_embedding, score = speech2xvector(**batch)
# Only supporting batch_size==1
key = keys[0]
normalized_score = 0.0
if score is not None:
score = score.item()
normalized_score = max(score - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
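# Worked example (comment added for clarity): with the default
# sv_threshold=0.9465, a raw cosine score of 0.97 maps to
# (0.97 - 0.9465) / (1 - 0.9465) * 100 ≈ 43.9, while any score at or below
# the threshold maps to 0.0.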
item = {"key": key, "value": normalized_score}
else:
item = {"key": key, "value": embedding.squeeze(0).cpu().numpy()}
sv_result_list.append(item)
if output_path is not None:
embd_writer(key, embedding[0].cpu().numpy())
if ref_embedding is not None:
if ref_embd_writer is None:
ref_embd_writer = WriteHelper(
"ark,scp:{}/ref_xvector.ark,{}/ref_xvector.scp".format(output_path, output_path)
)
score_writer = open(os.path.join(output_path, "score.txt"), "w")
ref_embd_writer(key, ref_embedding[0].cpu().numpy())
score_writer.write("{} {:.6f}\n".format(key, normalized_score))
if output_path is not None:
embd_writer.close()
if ref_embd_writer is not None:
ref_embd_writer.close()
score_writer.close()
return sv_result_list
return _forward
def inference(
output_dir: Optional[str],
batch_size: int,
dtype: str,
ngpu: int,
seed: int,
num_workers: int,
log_level: Union[int, str],
data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
key_file: Optional[str],
sv_train_config: Optional[str],
sv_model_file: Optional[str],
model_tag: Optional[str],
allow_variable_data_keys: bool = True,
streaming: bool = False,
embedding_node: str = "resnet1_dense",
sv_threshold: float = 0.9465,
**kwargs,
):
inference_pipeline = inference_modelscope(
output_dir=output_dir,
batch_size=batch_size,
dtype=dtype,
ngpu=ngpu,
seed=seed,
num_workers=num_workers,
log_level=log_level,
key_file=key_file,
sv_train_config=sv_train_config,
sv_model_file=sv_model_file,
model_tag=model_tag,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
embedding_node=embedding_node,
sv_threshold=sv_threshold,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs=None)
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speaker verification/x-vector extraction",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=False)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--sv_train_config",
type=str,
help="SV training configuration",
)
group.add_argument(
"--sv_model_file",
type=str,
help="SV model parameter file",
)
group.add_argument(
"--sv_threshold",
type=float,
default=0.9465,
help="The threshold for verification"
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
parser.add_argument("--streaming", type=str2bool, default=False)
parser.add_argument("--embedding_node", type=str, default="resnet1_dense")
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
logging.info("args: {}".format(kwargs))
if args.output_dir is None:
jobid, n_gpu = 1, 1
gpuid = args.gpuid_list.split(",")[jobid-1]
else:
jobid = int(args.output_dir.split(".")[-1])
n_gpu = len(args.gpuid_list.split(","))
gpuid = args.gpuid_list.split(",")[(jobid - 1) % n_gpu]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
results_list = inference(**kwargs)
for results in results_list:
print("{} {}".format(results["key"], results["value"]))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,174 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speaker Verification",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=False)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--njob",
type=int,
default=1,
help="The number of jobs for each gpu",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=True)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",
type=str,
help="VAD infer configuration",
)
group.add_argument(
"--vad_model_file",
type=str,
help="VAD model parameter file",
)
group.add_argument(
"--sv_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--sv_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global CMVN file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("The inference configuration related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument(
"--sv_threshold",
type=float,
default=0.9465,
help="The threshold for verification"
)
parser.add_argument(
"--embedding_node",
type=str,
default="resnet1_dense",
help="The network node to extract embedding"
)
return parser
def inference_launch(mode, **kwargs):
if mode == "sv":
from funasr_local.bin.sv_inference import inference_modelscope
return inference_modelscope(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
parser.add_argument(
"--mode",
type=str,
default="sv",
help="The decoding mode",
)
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
# set logging messages
logging.basicConfig(
level=args.log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("Decoding args: {}".format(kwargs))
# gpu setting
if args.ngpu > 0:
jobid = int(args.output_dir.split(".")[-1])
gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
inference_launch(**kwargs)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,283 @@
#!/usr/bin/env python3
import argparse
from collections import Counter
import logging
from pathlib import Path
import sys
from typing import List
from typing import Optional
from typeguard import check_argument_types
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.text.build_tokenizer import build_tokenizer
from funasr_local.text.cleaner import TextCleaner
from funasr_local.text.phoneme_tokenizer import g2p_choices
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str_or_none
def field2slice(field: Optional[str]) -> slice:
"""Convert field string to slice
Note that field string accepts 1-based integer.
Examples:
>>> field2slice("1-")
slice(0, None, None)
>>> field2slice("1-3")
slice(0, 3, None)
>>> field2slice("-3")
slice(None, 3, None)
"""
field = field.strip()
try:
if "-" in field:
# e.g. "2-" or "2-5" or "-7"
s1, s2 = field.split("-", maxsplit=1)
if s1.strip() == "":
s1 = None
else:
s1 = int(s1)
if s1 == 0:
raise ValueError("1-based string")
if s2.strip() == "":
s2 = None
else:
s2 = int(s2)
else:
# e.g. "2"
s1 = int(field)
s2 = s1 + 1
if s1 == 0:
raise ValueError("must be 1 or more value")
except ValueError:
raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}")
if s1 is None:
slic = slice(None, s2)
else:
# -1 because of 1-based integer following "cut" command
# e.g "1-3" -> slice(0, 3)
slic = slice(s1 - 1, s2)
return slic
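# Additional example (comment added for clarity, mirroring the docstring above):
#   field2slice("2-5") -> slice(1, 5)   # 1-based columns 2..5, like `cut -f 2-5`
# so tokens[field2slice("2-")] keeps every column after the utterance id.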
def tokenize(
input: str,
output: str,
field: Optional[str],
delimiter: Optional[str],
token_type: str,
space_symbol: str,
non_linguistic_symbols: Optional[str],
bpemodel: Optional[str],
log_level: str,
write_vocabulary: bool,
vocabulary_size: int,
remove_non_linguistic_symbols: bool,
cutoff: int,
add_symbol: List[str],
cleaner: Optional[str],
g2p: Optional[str],
):
assert check_argument_types()
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if input == "-":
fin = sys.stdin
else:
fin = Path(input).open("r", encoding="utf-8")
if output == "-":
fout = sys.stdout
else:
p = Path(output)
p.parent.mkdir(parents=True, exist_ok=True)
fout = p.open("w", encoding="utf-8")
cleaner = TextCleaner(cleaner)
tokenizer = build_tokenizer(
token_type=token_type,
bpemodel=bpemodel,
delimiter=delimiter,
space_symbol=space_symbol,
non_linguistic_symbols=non_linguistic_symbols,
remove_non_linguistic_symbols=remove_non_linguistic_symbols,
g2p_type=g2p,
)
counter = Counter()
if field is not None:
field = field2slice(field)
for line in fin:
line = line.rstrip()
if field is not None:
# e.g. field="2-"
# uttidA hello world!! -> hello world!!
tokens = line.split(delimiter)
tokens = tokens[field]
if delimiter is None:
line = " ".join(tokens)
else:
line = delimiter.join(tokens)
line = cleaner(line)
tokens = tokenizer.text2tokens(line)
if not write_vocabulary:
fout.write(" ".join(tokens) + "\n")
else:
for t in tokens:
counter[t] += 1
if not write_vocabulary:
return
# FIXME: drop symbols that will be re-added via --add_symbol from the counter
for symbol_and_id in add_symbol:
# e.g symbol="<blank>:0"
try:
symbol, idx = symbol_and_id.split(":")
except ValueError:
raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
symbol = symbol.strip()
if symbol in counter:
del counter[symbol]
# ======= write_vocabulary mode from here =======
# Sort by the number of occurrences in descending order
# and filter lower frequency words than cutoff value
words_and_counts = list(
filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
)
# Restrict the vocabulary size
if vocabulary_size > 0:
if vocabulary_size < len(add_symbol):
raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]
# Parse the values of --add_symbol
for symbol_and_id in add_symbol:
# e.g symbol="<blank>:0"
try:
symbol, idx = symbol_and_id.split(":")
idx = int(idx)
except ValueError:
raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
symbol = symbol.strip()
# e.g. idx=0 -> append as the first symbol
# e.g. idx=-1 -> append as the last symbol
if idx < 0:
idx = len(words_and_counts) + 1 + idx
words_and_counts.insert(idx, (symbol, None))
# Write words
for w, c in words_and_counts:
fout.write(w + "\n")
# Logging
total_count = sum(counter.values())
invocab_count = sum(c for w, c in words_and_counts if c is not None)
logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Tokenize texts",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument(
"--input", "-i", required=True, help="Input text. - indicates sys.stdin"
)
parser.add_argument(
"--output", "-o", required=True, help="Output text. - indicates sys.stdout"
)
parser.add_argument(
"--field",
"-f",
help="The target columns of the input text as 1-based integer. e.g 2-",
)
parser.add_argument(
"--token_type",
"-t",
default="char",
choices=["char", "bpe", "word", "phn"],
help="Token type",
)
parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
parser.add_argument(
"--non_linguistic_symbols",
type=str_or_none,
help="non_linguistic_symbols file path",
)
parser.add_argument(
"--remove_non_linguistic_symbols",
type=str2bool,
default=False,
help="Remove non-language-symbols from tokens",
)
parser.add_argument(
"--cleaner",
type=str_or_none,
choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
default=None,
help="Apply text cleaning",
)
parser.add_argument(
"--g2p",
type=str_or_none,
choices=g2p_choices,
default=None,
help="Specify g2p method if --token_type=phn",
)
group = parser.add_argument_group("write_vocabulary mode related")
group.add_argument(
"--write_vocabulary",
type=str2bool,
default=False,
help="Write tokens list instead of tokenized text per line",
)
group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
group.add_argument(
"--cutoff",
default=0,
type=int,
help="cut-off frequency used for write-vocabulary mode",
)
group.add_argument(
"--add_symbol",
type=str,
default=[],
action="append",
help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
tokenize(**kwargs)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,399 @@
import argparse
import logging
from optparse import Option
import sys
import json
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
from funasr_local.fileio.datadir_writer import DatadirWriter
from funasr_local.datasets.preprocessor import LMPreprocessor
from funasr_local.tasks.asr import ASRTaskAligner as ASRTask
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from funasr_local.models.frontend.wav_frontend import WavFrontend
from funasr_local.text.token_id_converter import TokenIDConverter
from funasr_local.utils.timestamp_tools import ts_prediction_lfr6_standard
header_colors = '\033[95m'
end_colors = '\033[0m'
global_asr_language: str = 'zh-cn'
global_sample_rate: Union[int, Dict[Any, int]] = {
'audio_fs': 16000,
'model_fs': 16000
}
class SpeechText2Timestamp:
def __init__(
self,
timestamp_infer_config: Union[Path, str] = None,
timestamp_model_file: Union[Path, str] = None,
timestamp_cmvn_file: Union[Path, str] = None,
device: str = "cpu",
dtype: str = "float32",
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
tp_model, tp_train_args = ASRTask.build_model_from_file(
timestamp_infer_config, timestamp_model_file, device=device
)
if 'cuda' in device:
tp_model = tp_model.cuda() # force model to cuda
frontend = None
if tp_train_args.frontend is not None:
frontend = WavFrontend(cmvn_file=timestamp_cmvn_file, **tp_train_args.frontend_conf)
logging.info("tp_model: {}".format(tp_model))
logging.info("tp_train_args: {}".format(tp_train_args))
tp_model.to(dtype=getattr(torch, dtype)).eval()
logging.info(f"Decoding device={device}, dtype={dtype}")
self.tp_model = tp_model
self.tp_train_args = tp_train_args
token_list = self.tp_model.token_list
self.converter = TokenIDConverter(token_list=token_list)
self.device = device
self.dtype = dtype
self.frontend = frontend
self.encoder_downsampling_factor = 1
if tp_train_args.encoder_conf["input_layer"] == "conv2d":
self.encoder_downsampling_factor = 4
@torch.no_grad()
def __call__(
self,
speech: Union[torch.Tensor, np.ndarray],
speech_lengths: Union[torch.Tensor, np.ndarray] = None,
text_lengths: Union[torch.Tensor, np.ndarray] = None
):
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
self.tp_model.frontend = None
else:
feats = speech
feats_len = speech_lengths
# lfr_factor = max(1, (feats.size()[-1]//80)-1)
batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
# b. Forward Encoder
enc, enc_len = self.tp_model.encode(**batch)
if isinstance(enc, tuple):
enc = enc[0]
# c. Forward Predictor
_, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1)
return us_alphas, us_peaks
def inference(
batch_size: int,
ngpu: int,
log_level: Union[int, str],
data_path_and_name_and_type,
timestamp_infer_config: Optional[str],
timestamp_model_file: Optional[str],
timestamp_cmvn_file: Optional[str] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
key_file: Optional[str] = None,
allow_variable_data_keys: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
num_workers: int = 1,
split_with_space: bool = True,
seg_dict_file: Optional[str] = None,
**kwargs,
):
inference_pipeline = inference_modelscope(
batch_size=batch_size,
ngpu=ngpu,
log_level=log_level,
timestamp_infer_config=timestamp_infer_config,
timestamp_model_file=timestamp_model_file,
timestamp_cmvn_file=timestamp_cmvn_file,
key_file=key_file,
allow_variable_data_keys=allow_variable_data_keys,
output_dir=output_dir,
dtype=dtype,
seed=seed,
num_workers=num_workers,
split_with_space=split_with_space,
seg_dict_file=seg_dict_file,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
batch_size: int,
ngpu: int,
log_level: Union[int, str],
# data_path_and_name_and_type,
timestamp_infer_config: Optional[str],
timestamp_model_file: Optional[str],
timestamp_cmvn_file: Optional[str] = None,
# raw_inputs: Union[np.ndarray, torch.Tensor] = None,
key_file: Optional[str] = None,
allow_variable_data_keys: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
num_workers: int = 1,
split_with_space: bool = True,
seg_dict_file: Optional[str] = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2vadsegment
speechtext2timestamp_kwargs = dict(
timestamp_infer_config=timestamp_infer_config,
timestamp_model_file=timestamp_model_file,
timestamp_cmvn_file=timestamp_cmvn_file,
device=device,
dtype=dtype,
)
logging.info("speechtext2timestamp_kwargs: {}".format(speechtext2timestamp_kwargs))
speechtext2timestamp = SpeechText2Timestamp(**speechtext2timestamp_kwargs)
preprocessor = LMPreprocessor(
train=False,
token_type=speechtext2timestamp.tp_train_args.token_type,
token_list=speechtext2timestamp.tp_train_args.token_list,
bpemodel=None,
text_cleaner=None,
g2p_type=None,
text_name="text",
non_linguistic_symbols=speechtext2timestamp.tp_train_args.non_linguistic_symbols,
split_with_space=split_with_space,
seg_dict_file=seg_dict_file,
)
if output_dir is not None:
writer = DatadirWriter(output_dir)
tp_writer = writer[f"timestamp_prediction"]
# ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
else:
tp_writer = None
def _forward(
data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
**kwargs
):
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
writer = None
if output_path is not None:
writer = DatadirWriter(output_path)
tp_writer = writer[f"timestamp_prediction"]
else:
tp_writer = None
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=preprocessor,
collate_fn=ASRTask.build_collate_fn(speechtext2timestamp.tp_train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
tp_result_list = []
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
logging.info("timestamp predicting, utt_id: {}".format(keys))
_batch = {'speech':batch['speech'],
'speech_lengths':batch['speech_lengths'],
'text_lengths':batch['text_lengths']}
us_alphas, us_cif_peak = speechtext2timestamp(**_batch)
for batch_id in range(_bs):
key = keys[batch_id]
token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id])
ts_str, ts_list = ts_prediction_lfr6_standard(us_alphas[batch_id], us_cif_peak[batch_id], token, force_time_shift=-3.0)
logging.warning(ts_str)
item = {'key': key, 'value': ts_str, 'timestamp':ts_list}
if tp_writer is not None:
tp_writer["tp_sync"][key+'#'] = ts_str
tp_writer["tp_time"][key+'#'] = str(ts_list)
tp_result_list.append(item)
return tp_result_list
return _forward
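# Sketch of the factory pattern used above (illustrative; file names are placeholders):
# inference_modelscope() builds the model once and returns the `_forward` closure,
# which can then be called repeatedly with new data triples or raw inputs.
#   pipeline = inference_modelscope(batch_size=1, ngpu=0, log_level="INFO",
#                                   timestamp_infer_config="config.yaml",
#                                   timestamp_model_file="model.pb",
#                                   timestamp_cmvn_file="am.mvn")
#   results = pipeline([("wav.scp", "speech", "sound"), ("text.txt", "text", "text")])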
def get_parser():
parser = config_argparse.ArgumentParser(
description="Timestamp Prediction Inference",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=False)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=0,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--timestamp_infer_config",
type=str,

help="VAD infer configuration",
)
group.add_argument(
"--timestamp_model_file",
type=str,
help="VAD model parameter file",
)
group.add_argument(
"--timestamp_cmvn_file",
type=str,
help="Global cmvn file",
)
group = parser.add_argument_group("infer related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument(
"--seg_dict_file",
type=str,
default=None,
help="The batch size for inference",
)
    group.add_argument(
        "--split_with_space",
        type=str2bool,
        default=False,
        help="Whether to split the input text by spaces",
    )
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()
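# Example command line (illustrative; paths are placeholders, and the triple follows the
# str2triple_str "path,name,type" convention used elsewhere in this repository):
#   python -m funasr_local.bin.tp_inference \
#       --timestamp_infer_config config.yaml \
#       --timestamp_model_file model.pb \
#       --timestamp_cmvn_file am.mvn \
#       --data_path_and_name_and_type wav.scp,speech,sound \
#       --data_path_and_name_and_type text.txt,text,text \
#       --output_dir exp/tp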

View File

@@ -0,0 +1,142 @@
#!/usr/bin/env python3
import argparse
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
def get_parser():
parser = config_argparse.ArgumentParser(
description="Timestamp Prediction Inference",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=False)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--njob",
type=int,
default=1,
help="The number of jobs for each gpu",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=True,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--timestamp_infer_config",
type=str,
help="VAD infer configuration",
)
group.add_argument(
"--timestamp_model_file",
type=str,
help="VAD model parameter file",
)
group.add_argument(
"--timestamp_cmvn_file",
type=str,
help="Global CMVN file",
)
group = parser.add_argument_group("The inference configuration related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
return parser
def inference_launch(mode, **kwargs):
if mode == "tp_norm":
from funasr_local.bin.tp_inference import inference_modelscope
return inference_modelscope(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
parser.add_argument(
"--mode",
type=str,
default="tp_norm",
help="The decoding mode",
)
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
# set logging messages
logging.basicConfig(
level=args.log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("Decoding args: {}".format(kwargs))
# gpu setting
if args.ngpu > 0:
jobid = int(args.output_dir.split(".")[-1])
gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
inference_launch(**kwargs)
if __name__ == "__main__":
main()
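# Example launcher invocation (illustrative; assumes this file is installed as
# funasr_local/bin/tp_inference_launch.py and all paths are placeholders). With --ngpu > 0
# the job id is parsed from the trailing ".N" of --output_dir and mapped onto --gpuid_list.
#   python -m funasr_local.bin.tp_inference_launch \
#       --mode tp_norm --ngpu 1 --gpuid_list 0 --njob 1 \
#       --timestamp_infer_config config.yaml --timestamp_model_file model.pb \
#       --data_path_and_name_and_type wav.scp,speech,sound \
#       --output_dir exp/tp/logdir/inference.1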

View File

@@ -0,0 +1,575 @@
import argparse
import logging
import os
import sys
import json
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import math
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.fileio.datadir_writer import DatadirWriter
from funasr_local.modules.scorers.scorer_interface import BatchScorerInterface
from funasr_local.modules.subsampling import TooShortUttError
from funasr_local.tasks.vad import VADTask
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from funasr_local.utils import asr_utils, wav_utils, postprocess_utils
from funasr_local.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
header_colors = '\033[95m'
end_colors = '\033[0m'
global_asr_language: str = 'zh-cn'
global_sample_rate: Union[int, Dict[Any, int]] = {
'audio_fs': 16000,
'model_fs': 16000
}
class Speech2VadSegment:
"""Speech2VadSegment class
Examples:
>>> import soundfile
>>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2segment(audio)
[[10, 230], [245, 450], ...]
"""
def __init__(
self,
vad_infer_config: Union[Path, str] = None,
vad_model_file: Union[Path, str] = None,
vad_cmvn_file: Union[Path, str] = None,
device: str = "cpu",
batch_size: int = 1,
dtype: str = "float32",
**kwargs,
):
assert check_argument_types()
# 1. Build vad model
vad_model, vad_infer_args = VADTask.build_model_from_file(
vad_infer_config, vad_model_file, device
)
frontend = None
if vad_infer_args.frontend is not None:
frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
logging.info("vad_model: {}".format(vad_model))
logging.info("vad_infer_args: {}".format(vad_infer_args))
vad_model.to(dtype=getattr(torch, dtype)).eval()
self.vad_model = vad_model
self.vad_infer_args = vad_infer_args
self.device = device
self.dtype = dtype
self.frontend = frontend
self.batch_size = batch_size
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
in_cache: Dict[str, torch.Tensor] = dict()
) -> Tuple[List[List[int]], Dict[str, torch.Tensor]]:
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if self.frontend is not None:
self.frontend.filter_length_max = math.inf
fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
fbanks = to_device(fbanks, device=self.device)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
else:
raise Exception("Need to extract feats first, please configure frontend configuration")
# b. Forward Encoder streaming
t_offset = 0
step = min(feats_len.max(), 6000)
segments = [[]] * self.batch_size
for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
if t_offset + step >= feats_len - 1:
step = feats_len - t_offset
is_final = True
else:
is_final = False
batch = {
"feats": feats[:, t_offset:t_offset + step, :],
"waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
"is_final": is_final,
"in_cache": in_cache
}
# a. To device
#batch = to_device(batch, device=self.device)
segments_part, in_cache = self.vad_model(**batch)
if segments_part:
for batch_num in range(0, self.batch_size):
segments[batch_num] += segments_part[batch_num]
return fbanks, segments
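# Note on the loop above: features are consumed in chunks of at most 6000 frames, with
# is_final set on the last chunk so the VAD model can flush pending segments.
# Illustrative offline usage (file names are placeholders):
#   s2vs = Speech2VadSegment(vad_infer_config="vad.yaml", vad_model_file="vad.pb",
#                            vad_cmvn_file="vad.mvn", device="cpu")
#   fbanks, segments = s2vs(speech, speech_lengths)  # segments: one [start, end] list per utt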
class Speech2VadSegmentOnline(Speech2VadSegment):
"""Speech2VadSegmentOnline class
Examples:
>>> import soundfile
>>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2segment(audio)
[[10, 230], [245, 450], ...]
"""
def __init__(self, **kwargs):
super(Speech2VadSegmentOnline, self).__init__(**kwargs)
vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
self.frontend = None
if self.vad_infer_args.frontend is not None:
self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False, max_end_sil: int = 800
) -> Tuple[torch.Tensor, List[List[int]], torch.Tensor]:
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
batch_size = speech.shape[0]
segments = [[]] * batch_size
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths, is_final)
fbanks, _ = self.frontend.get_fbank()
else:
raise Exception("Need to extract feats first, please configure frontend configuration")
if feats.shape[0]:
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
waveforms = self.frontend.get_waveforms()
batch = {
"feats": feats,
"waveform": waveforms,
"in_cache": in_cache,
"is_final": is_final,
"max_end_sil": max_end_sil
}
# a. To device
batch = to_device(batch, device=self.device)
segments, in_cache = self.vad_model.forward_online(**batch)
# in_cache.update(batch['in_cache'])
# in_cache = {key: value for key, value in batch['in_cache'].items()}
return fbanks, segments, in_cache
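# Illustrative streaming usage (hypothetical chunking): feed audio chunk by chunk, carry
# `in_cache` between calls, and mark the last chunk with is_final=True.
#   cache = {}
#   for i, chunk in enumerate(chunks):                      # chunk: (1, chunk_samples)
#       lengths = torch.tensor([chunk.shape[1]])
#       _, segs, cache = s2vs_online(chunk, lengths, in_cache=cache,
#                                    is_final=(i == len(chunks) - 1))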
def inference(
batch_size: int,
ngpu: int,
log_level: Union[int, str],
data_path_and_name_and_type,
vad_infer_config: Optional[str],
vad_model_file: Optional[str],
vad_cmvn_file: Optional[str] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
key_file: Optional[str] = None,
allow_variable_data_keys: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
num_workers: int = 1,
online: bool = False,
**kwargs,
):
if not online:
inference_pipeline = inference_modelscope(
batch_size=batch_size,
ngpu=ngpu,
log_level=log_level,
vad_infer_config=vad_infer_config,
vad_model_file=vad_model_file,
vad_cmvn_file=vad_cmvn_file,
key_file=key_file,
allow_variable_data_keys=allow_variable_data_keys,
output_dir=output_dir,
dtype=dtype,
seed=seed,
num_workers=num_workers,
**kwargs,
)
else:
inference_pipeline = inference_modelscope_online(
batch_size=batch_size,
ngpu=ngpu,
log_level=log_level,
vad_infer_config=vad_infer_config,
vad_model_file=vad_model_file,
vad_cmvn_file=vad_cmvn_file,
key_file=key_file,
allow_variable_data_keys=allow_variable_data_keys,
output_dir=output_dir,
dtype=dtype,
seed=seed,
num_workers=num_workers,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
batch_size: int,
ngpu: int,
log_level: Union[int, str],
# data_path_and_name_and_type,
vad_infer_config: Optional[str],
vad_model_file: Optional[str],
vad_cmvn_file: Optional[str] = None,
# raw_inputs: Union[np.ndarray, torch.Tensor] = None,
key_file: Optional[str] = None,
allow_variable_data_keys: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
num_workers: int = 1,
**kwargs,
):
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2vadsegment
speech2vadsegment_kwargs = dict(
vad_infer_config=vad_infer_config,
vad_model_file=vad_model_file,
vad_cmvn_file=vad_cmvn_file,
device=device,
dtype=dtype,
)
logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
def _forward(
data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = VADTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
finish_count = 0
file_count = 1
        # 7. Start for-loop
        # FIXME(kamo): The output format should be discussed
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
writer = DatadirWriter(output_path)
ibest_writer = writer[f"1best_recog"]
else:
writer = None
ibest_writer = None
vad_results = []
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# do vad segment
_, results = speech2vadsegment(**batch)
for i, _ in enumerate(keys):
if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
results[i] = json.dumps(results[i])
item = {'key': keys[i], 'value': results[i]}
vad_results.append(item)
if writer is not None:
results[i] = json.loads(results[i])
ibest_writer["text"][keys[i]] = "{}".format(results[i])
return vad_results
return _forward
def inference_modelscope_online(
batch_size: int,
ngpu: int,
log_level: Union[int, str],
# data_path_and_name_and_type,
vad_infer_config: Optional[str],
vad_model_file: Optional[str],
vad_cmvn_file: Optional[str] = None,
# raw_inputs: Union[np.ndarray, torch.Tensor] = None,
key_file: Optional[str] = None,
allow_variable_data_keys: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
num_workers: int = 1,
**kwargs,
):
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2vadsegment
speech2vadsegment_kwargs = dict(
vad_infer_config=vad_infer_config,
vad_model_file=vad_model_file,
vad_cmvn_file=vad_cmvn_file,
device=device,
dtype=dtype,
)
logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
speech2vadsegment = Speech2VadSegmentOnline(**speech2vadsegment_kwargs)
def _forward(
data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = VADTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
finish_count = 0
file_count = 1
        # 7. Start for-loop
        # FIXME(kamo): The output format should be discussed
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
writer = DatadirWriter(output_path)
ibest_writer = writer[f"1best_recog"]
else:
writer = None
ibest_writer = None
vad_results = []
batch_in_cache = param_dict['in_cache'] if param_dict is not None else dict()
is_final = param_dict.get('is_final', False) if param_dict is not None else False
max_end_sil = param_dict.get('max_end_sil', 800) if param_dict is not None else 800
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
batch['in_cache'] = batch_in_cache
batch['is_final'] = is_final
batch['max_end_sil'] = max_end_sil
# do vad segment
_, results, param_dict['in_cache'] = speech2vadsegment(**batch)
# param_dict['in_cache'] = batch['in_cache']
if results:
for i, _ in enumerate(keys):
if results[i]:
if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
results[i] = json.dumps(results[i])
item = {'key': keys[i], 'value': results[i]}
vad_results.append(item)
if writer is not None:
results[i] = json.loads(results[i])
ibest_writer["text"][keys[i]] = "{}".format(results[i])
return vad_results
return _forward
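# Sketch of how the online pipeline above is typically driven (illustrative; the caller
# owns `param_dict`, and the VAD cache persists inside it across calls):
#   pipeline = inference_modelscope_online(batch_size=1, ngpu=0, log_level="INFO",
#                                          vad_infer_config="vad.yaml",
#                                          vad_model_file="vad.pb", vad_cmvn_file="vad.mvn")
#   param_dict = {"in_cache": {}, "is_final": False, "max_end_sil": 800}
#   results = pipeline(None, raw_inputs=chunk, param_dict=param_dict)  # chunk: numpy waveform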
def get_parser():
parser = config_argparse.ArgumentParser(
description="VAD Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=False)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",
type=str,
help="VAD infer configuration",
)
group.add_argument(
"--vad_model_file",
type=str,
help="VAD model parameter file",
)
group.add_argument(
"--vad_cmvn_file",
type=str,
help="Global cmvn file",
)
    group.add_argument(
        "--online",
        type=str2bool,
        default=False,
        help="Whether to use online (streaming) decoding",
    )
group = parser.add_argument_group("infer related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,154 @@
#!/usr/bin/env python3
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
def get_parser():
parser = config_argparse.ArgumentParser(
description="VAD Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--njob",
type=int,
default=1,
help="The number of jobs for each gpu",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=True,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",
type=str,
help="VAD infer configuration",
)
group.add_argument(
"--vad_model_file",
type=str,
help="VAD model parameter file",
)
group.add_argument(
"--vad_cmvn_file",
type=str,
help="Global CMVN file",
)
group.add_argument(
"--vad_train_config",
type=str,
help="VAD training configuration",
)
group = parser.add_argument_group("The inference configuration related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
return parser
def inference_launch(mode, **kwargs):
if mode == "offline":
from funasr_local.bin.vad_inference import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "online":
from funasr_local.bin.vad_inference import inference_modelscope_online
return inference_modelscope_online(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
parser.add_argument(
"--mode",
type=str,
default="vad",
help="The decoding mode",
)
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
# set logging messages
logging.basicConfig(
level=args.log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("Decoding args: {}".format(kwargs))
# gpu setting
if args.ngpu > 0:
jobid = int(args.output_dir.split(".")[-1])
gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
inference_launch(**kwargs)
if __name__ == "__main__":
main()
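# Example launcher invocation (illustrative; paths are placeholders). --mode selects the
# offline or online pipeline; with --ngpu > 0 the job id is read from the ".N" suffix of
# --output_dir and mapped onto --gpuid_list via --njob.
#   python -m funasr_local.bin.vad_inference_launch \
#       --mode offline --ngpu 1 --gpuid_list 0 --njob 1 \
#       --vad_infer_config vad.yaml --vad_model_file vad.pb --vad_cmvn_file vad.mvn \
#       --data_path_and_name_and_type wav.scp,speech,sound \
#       --output_dir exp/vad/logdir/inference.1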

View File

@@ -0,0 +1,347 @@
import argparse
import logging
import os
import sys
import json
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.fileio.datadir_writer import DatadirWriter
from funasr_local.tasks.vad import VADTask
from funasr_local.torch_utils.device_funcs import to_device
from funasr_local.torch_utils.set_all_random_seed import set_all_random_seed
from funasr_local.utils import config_argparse
from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str2triple_str
from funasr_local.utils.types import str_or_none
from funasr_local.models.frontend.wav_frontend import WavFrontendOnline
from funasr_local.models.frontend.wav_frontend import WavFrontend
from funasr_local.bin.vad_inference import Speech2VadSegment
header_colors = '\033[95m'
end_colors = '\033[0m'
class Speech2VadSegmentOnline(Speech2VadSegment):
"""Speech2VadSegmentOnline class
Examples:
>>> import soundfile
>>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2segment(audio)
[[10, 230], [245, 450], ...]
"""
def __init__(self, **kwargs):
super(Speech2VadSegmentOnline, self).__init__(**kwargs)
vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
self.frontend = None
if self.vad_infer_args.frontend is not None:
self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False, max_end_sil: int = 800
) -> Tuple[torch.Tensor, List[List[int]], torch.Tensor]:
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
batch_size = speech.shape[0]
segments = [[]] * batch_size
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths, is_final)
fbanks, _ = self.frontend.get_fbank()
else:
raise Exception("Need to extract feats first, please configure frontend configuration")
if feats.shape[0]:
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
waveforms = self.frontend.get_waveforms()
batch = {
"feats": feats,
"waveform": waveforms,
"in_cache": in_cache,
"is_final": is_final,
"max_end_sil": max_end_sil
}
# a. To device
batch = to_device(batch, device=self.device)
segments, in_cache = self.vad_model.forward_online(**batch)
# in_cache.update(batch['in_cache'])
# in_cache = {key: value for key, value in batch['in_cache'].items()}
return fbanks, segments, in_cache
def inference(
batch_size: int,
ngpu: int,
log_level: Union[int, str],
data_path_and_name_and_type,
vad_infer_config: Optional[str],
vad_model_file: Optional[str],
vad_cmvn_file: Optional[str] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
key_file: Optional[str] = None,
allow_variable_data_keys: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
num_workers: int = 1,
**kwargs,
):
inference_pipeline = inference_modelscope(
batch_size=batch_size,
ngpu=ngpu,
log_level=log_level,
vad_infer_config=vad_infer_config,
vad_model_file=vad_model_file,
vad_cmvn_file=vad_cmvn_file,
key_file=key_file,
allow_variable_data_keys=allow_variable_data_keys,
output_dir=output_dir,
dtype=dtype,
seed=seed,
num_workers=num_workers,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
batch_size: int,
ngpu: int,
log_level: Union[int, str],
# data_path_and_name_and_type,
vad_infer_config: Optional[str],
vad_model_file: Optional[str],
vad_cmvn_file: Optional[str] = None,
# raw_inputs: Union[np.ndarray, torch.Tensor] = None,
key_file: Optional[str] = None,
allow_variable_data_keys: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
num_workers: int = 1,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2vadsegment
speech2vadsegment_kwargs = dict(
vad_infer_config=vad_infer_config,
vad_model_file=vad_model_file,
vad_cmvn_file=vad_cmvn_file,
device=device,
dtype=dtype,
)
logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
speech2vadsegment = Speech2VadSegmentOnline(**speech2vadsegment_kwargs)
def _forward(
data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = VADTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
finish_count = 0
file_count = 1
        # 7. Start for-loop
        # FIXME(kamo): The output format should be discussed
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
writer = DatadirWriter(output_path)
ibest_writer = writer[f"1best_recog"]
else:
writer = None
ibest_writer = None
vad_results = []
batch_in_cache = param_dict['in_cache'] if param_dict is not None else dict()
is_final = param_dict.get('is_final', False) if param_dict is not None else False
max_end_sil = param_dict.get('max_end_sil', 800) if param_dict is not None else 800
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
batch['in_cache'] = batch_in_cache
batch['is_final'] = is_final
batch['max_end_sil'] = max_end_sil
# do vad segment
_, results, param_dict['in_cache'] = speech2vadsegment(**batch)
# param_dict['in_cache'] = batch['in_cache']
if results:
for i, _ in enumerate(keys):
if results[i]:
if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
results[i] = json.dumps(results[i])
item = {'key': keys[i], 'value': results[i]}
vad_results.append(item)
if writer is not None:
results[i] = json.loads(results[i])
ibest_writer["text"][keys[i]] = "{}".format(results[i])
return vad_results
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="VAD Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=False)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",
type=str,
help="VAD infer configuration",
)
group.add_argument(
"--vad_model_file",
type=str,
help="VAD model parameter file",
)
group.add_argument(
"--vad_cmvn_file",
type=str,
help="Global cmvn file",
)
group = parser.add_argument_group("infer related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()

View File

View File

@@ -0,0 +1,135 @@
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.modules.nets_utils import pad_list
class CommonCollateFn:
"""Functor class of common_collate_fn()"""
def __init__(
self,
float_pad_value: Union[float, int] = 0.0,
int_pad_value: int = -32768,
not_sequence: Collection[str] = (),
max_sample_size=None
):
assert check_argument_types()
self.float_pad_value = float_pad_value
self.int_pad_value = int_pad_value
self.not_sequence = set(not_sequence)
self.max_sample_size = max_sample_size
def __repr__(self):
return (
f"{self.__class__}(float_pad_value={self.float_pad_value}, "
f"int_pad_value={self.float_pad_value})"
)
def __call__(
self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
return common_collate_fn(
data,
float_pad_value=self.float_pad_value,
int_pad_value=self.int_pad_value,
not_sequence=self.not_sequence,
)
def common_collate_fn(
data: Collection[Tuple[str, Dict[str, np.ndarray]]],
float_pad_value: Union[float, int] = 0.0,
int_pad_value: int = -32768,
not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
"""Concatenate ndarray-list to an array and convert to torch.Tensor.
"""
assert check_argument_types()
uttids = [u for u, _ in data]
data = [d for _, d in data]
assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
assert all(
not k.endswith("_lengths") for k in data[0]
), f"*_lengths is reserved: {list(data[0])}"
output = {}
for key in data[0]:
if data[0][key].dtype.kind == "i":
pad_value = int_pad_value
else:
pad_value = float_pad_value
array_list = [d[key] for d in data]
tensor_list = [torch.from_numpy(a) for a in array_list]
tensor = pad_list(tensor_list, pad_value)
output[key] = tensor
if key not in not_sequence:
lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
output[key + "_lengths"] = lens
output = (uttids, output)
assert check_return_type(output)
return output
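# Worked example of the padding behaviour above (shapes only; values illustrative):
#   data = [("utt1", {"speech": np.zeros((30, 80), dtype=np.float32)}),
#           ("utt2", {"speech": np.zeros((20, 80), dtype=np.float32))]
#   uttids, batch = common_collate_fn(data)
#   # batch["speech"].shape == (2, 30, 80)   (the shorter item is padded with float_pad_value)
#   # batch["speech_lengths"].tolist() == [30, 20]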
def crop_to_max_size(feature, target_size):
size = len(feature)
diff = size - target_size
if diff <= 0:
return feature
start = np.random.randint(0, diff + 1)
end = size - diff + start
return feature[start:end]
def clipping_collate_fn(
data: Collection[Tuple[str, Dict[str, np.ndarray]]],
max_sample_size=None,
not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
# mainly for pre-training
assert check_argument_types()
uttids = [u for u, _ in data]
data = [d for _, d in data]
assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
assert all(
not k.endswith("_lengths") for k in data[0]
), f"*_lengths is reserved: {list(data[0])}"
output = {}
for key in data[0]:
array_list = [d[key] for d in data]
tensor_list = [torch.from_numpy(a) for a in array_list]
sizes = [len(s) for s in tensor_list]
if max_sample_size is None:
target_size = min(sizes)
else:
target_size = min(min(sizes), max_sample_size)
tensor = tensor_list[0].new_zeros(len(tensor_list), target_size, tensor_list[0].shape[1])
for i, (source, size) in enumerate(zip(tensor_list, sizes)):
diff = size - target_size
if diff == 0:
tensor[i] = source
else:
tensor[i] = crop_to_max_size(source, target_size)
output[key] = tensor
if key not in not_sequence:
lens = torch.tensor([source.shape[0] for source in tensor], dtype=torch.long)
output[key + "_lengths"] = lens
output = (uttids, output)
assert check_return_type(output)
return output
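# Illustrative comparison: CommonCollateFn pads every feature up to the longest item in the
# batch, while clipping_collate_fn randomly crops every item down to the shortest length
# (or max_sample_size), which suits self-supervised pre-training batches.
#   collate = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
#   uttids, batch = collate(data)          # same output structure as common_collate_fn above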

View File

@@ -0,0 +1,448 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from abc import ABC
from abc import abstractmethod
import collections
import copy
import functools
import logging
import numbers
import re
from typing import Any
from typing import Callable
from typing import Collection
from typing import Dict
from typing import Mapping
from typing import Tuple
from typing import Union
import h5py
import humanfriendly
# import kaldiio
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.fileio.npy_scp import NpyScpReader
from funasr_local.fileio.rand_gen_dataset import FloatRandomGenerateDataset
from funasr_local.fileio.rand_gen_dataset import IntRandomGenerateDataset
from funasr_local.fileio.read_text import load_num_sequence_text
from funasr_local.fileio.read_text import read_2column_text
from funasr_local.fileio.sound_scp import SoundScpReader
from funasr_local.utils.sized_dict import SizedDict
class AdapterForSoundScpReader(collections.abc.Mapping):
def __init__(self, loader, dtype=None):
assert check_argument_types()
self.loader = loader
self.dtype = dtype
self.rate = None
def keys(self):
return self.loader.keys()
def __len__(self):
return len(self.loader)
def __iter__(self):
return iter(self.loader)
def __getitem__(self, key: str) -> np.ndarray:
retval = self.loader[key]
if isinstance(retval, tuple):
assert len(retval) == 2, len(retval)
if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
# sound scp case
rate, array = retval
elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
# Extended ark format case
array, rate = retval
else:
raise RuntimeError(
f"Unexpected type: {type(retval[0])}, {type(retval[1])}"
)
if self.rate is not None and self.rate != rate:
raise RuntimeError(
f"Sampling rates are mismatched: {self.rate} != {rate}"
)
self.rate = rate
            # Multichannel wave file
# array: (NSample, Channel) or (Nsample)
if self.dtype is not None:
array = array.astype(self.dtype)
else:
# Normal ark case
assert isinstance(retval, np.ndarray), type(retval)
array = retval
if self.dtype is not None:
array = array.astype(self.dtype)
assert isinstance(array, np.ndarray), type(array)
return array
class H5FileWrapper:
def __init__(self, path: str):
self.path = path
self.h5_file = h5py.File(path, "r")
def __repr__(self) -> str:
return str(self.h5_file)
def __len__(self) -> int:
return len(self.h5_file)
def __iter__(self):
return iter(self.h5_file)
def __getitem__(self, key) -> np.ndarray:
value = self.h5_file[key]
return value[()]
def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
# The file is as follows:
# utterance_id_A /some/where/a.wav
# utterance_id_B /some/where/a.flac
# NOTE(kamo): SoundScpReader doesn't support pipe-fashion
# like Kaldi e.g. "cat a.wav |".
# NOTE(kamo): The audio signal is normalized to [-1,1] range.
    loader = SoundScpReader(path, normalize=True, always_2d=False, dest_sample_rate=dest_sample_rate)
# SoundScpReader.__getitem__() returns Tuple[int, ndarray],
# but ndarray is desired, so Adapter class is inserted here
return AdapterForSoundScpReader(loader, float_dtype)
def kaldi_loader(path, float_dtype=None, max_cache_fd: int = 0):
    # kaldiio is imported lazily here because the module-level import above is commented out
    import kaldiio
    loader = kaldiio.load_scp(path, max_cache_fd=max_cache_fd)
    return AdapterForSoundScpReader(loader, float_dtype)
def rand_int_loader(filepath, loader_type):
# e.g. rand_int_3_10
try:
low, high = map(int, loader_type[len("rand_int_") :].split("_"))
except ValueError:
raise RuntimeError(f"e.g rand_int_3_10: but got {loader_type}")
return IntRandomGenerateDataset(filepath, low, high)
DATA_TYPES = {
"sound": dict(
func=sound_loader,
kwargs=["dest_sample_rate","float_dtype"],
help="Audio format types which supported by sndfile wav, flac, etc."
"\n\n"
" utterance_id_a a.wav\n"
" utterance_id_b b.wav\n"
" ...",
),
"kaldi_ark": dict(
func=kaldi_loader,
kwargs=["max_cache_fd"],
help="Kaldi-ark file type."
"\n\n"
" utterance_id_A /some/where/a.ark:123\n"
" utterance_id_B /some/where/a.ark:456\n"
" ...",
),
"npy": dict(
func=NpyScpReader,
kwargs=[],
help="Npy file format."
"\n\n"
" utterance_id_A /some/where/a.npy\n"
" utterance_id_B /some/where/b.npy\n"
" ...",
),
"text_int": dict(
func=functools.partial(load_num_sequence_text, loader_type="text_int"),
kwargs=[],
help="A text file in which is written a sequence of interger numbers "
"separated by space."
"\n\n"
" utterance_id_A 12 0 1 3\n"
" utterance_id_B 3 3 1\n"
" ...",
),
"csv_int": dict(
func=functools.partial(load_num_sequence_text, loader_type="csv_int"),
kwargs=[],
help="A text file in which is written a sequence of interger numbers "
"separated by comma."
"\n\n"
" utterance_id_A 100,80\n"
" utterance_id_B 143,80\n"
" ...",
),
"text_float": dict(
func=functools.partial(load_num_sequence_text, loader_type="text_float"),
kwargs=[],
help="A text file in which is written a sequence of float numbers "
"separated by space."
"\n\n"
" utterance_id_A 12. 3.1 3.4 4.4\n"
" utterance_id_B 3. 3.12 1.1\n"
" ...",
),
"csv_float": dict(
func=functools.partial(load_num_sequence_text, loader_type="csv_float"),
kwargs=[],
help="A text file in which is written a sequence of float numbers "
"separated by comma."
"\n\n"
" utterance_id_A 12.,3.1,3.4,4.4\n"
" utterance_id_B 3.,3.12,1.1\n"
" ...",
),
"text": dict(
func=read_2column_text,
kwargs=[],
help="Return text as is. The text must be converted to ndarray "
"by 'preprocess'."
"\n\n"
" utterance_id_A hello world\n"
" utterance_id_B foo bar\n"
" ...",
),
"hdf5": dict(
func=H5FileWrapper,
kwargs=[],
help="A HDF5 file which contains arrays at the first level or the second level."
" >>> f = h5py.File('file.h5')\n"
" >>> array1 = f['utterance_id_A']\n"
" >>> array2 = f['utterance_id_B']\n",
),
"rand_float": dict(
func=FloatRandomGenerateDataset,
kwargs=[],
help="Generate random float-ndarray which has the given shapes "
"in the file."
"\n\n"
" utterance_id_A 3,4\n"
" utterance_id_B 10,4\n"
" ...",
),
"rand_int_\\d+_\\d+": dict(
func=rand_int_loader,
kwargs=["loader_type"],
help="e.g. 'rand_int_0_10'. Generate random int-ndarray which has the given "
"shapes in the path. "
"Give the lower and upper value by the file type. e.g. "
"rand_int_0_10 -> Generate integers from 0 to 10."
"\n\n"
" utterance_id_A 3,4\n"
" utterance_id_B 10,4\n"
" ...",
),
}
class AbsDataset(Dataset, ABC):
@abstractmethod
def has_name(self, name) -> bool:
raise NotImplementedError
@abstractmethod
def names(self) -> Tuple[str, ...]:
raise NotImplementedError
@abstractmethod
def __getitem__(self, uid) -> Tuple[Any, Dict[str, np.ndarray]]:
raise NotImplementedError
class ESPnetDataset(AbsDataset):
"""Pytorch Dataset class for ESPNet.
Examples:
>>> dataset = ESPnetDataset([('wav.scp', 'input', 'sound'),
... ('token_int', 'output', 'text_int')],
... )
... uttid, data = dataset['uttid']
{'input': per_utt_array, 'output': per_utt_array}
"""
def __init__(
self,
path_name_type_list: Collection[Tuple[str, str, str]],
preprocess: Callable[
[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
] = None,
float_dtype: str = "float32",
int_dtype: str = "long",
max_cache_size: Union[float, int, str] = 0.0,
max_cache_fd: int = 0,
dest_sample_rate: int = 16000,
):
assert check_argument_types()
if len(path_name_type_list) == 0:
raise ValueError(
'1 or more elements are required for "path_name_type_list"'
)
path_name_type_list = copy.deepcopy(path_name_type_list)
self.preprocess = preprocess
self.float_dtype = float_dtype
self.int_dtype = int_dtype
self.max_cache_fd = max_cache_fd
self.dest_sample_rate = dest_sample_rate
self.loader_dict = {}
self.debug_info = {}
for path, name, _type in path_name_type_list:
if name in self.loader_dict:
raise RuntimeError(f'"{name}" is duplicated for data-key')
loader = self._build_loader(path, _type)
self.loader_dict[name] = loader
self.debug_info[name] = path, _type
if len(self.loader_dict[name]) == 0:
raise RuntimeError(f"{path} has no samples")
# TODO(kamo): Should check consistency of each utt-keys?
if isinstance(max_cache_size, str):
max_cache_size = humanfriendly.parse_size(max_cache_size)
self.max_cache_size = max_cache_size
if max_cache_size > 0:
self.cache = SizedDict(shared=True)
else:
self.cache = None
def _build_loader(
self, path: str, loader_type: str
) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
"""Helper function to instantiate Loader.
Args:
path: The file path
loader_type: loader_type. sound, npy, text_int, text_float, etc
"""
for key, dic in DATA_TYPES.items():
# e.g. loader_type="sound"
# -> return DATA_TYPES["sound"]["func"](path)
if re.match(key, loader_type):
kwargs = {}
for key2 in dic["kwargs"]:
if key2 == "loader_type":
kwargs["loader_type"] = loader_type
elif key2 == "dest_sample_rate" and loader_type=="sound":
kwargs["dest_sample_rate"] = self.dest_sample_rate
elif key2 == "float_dtype":
kwargs["float_dtype"] = self.float_dtype
elif key2 == "int_dtype":
kwargs["int_dtype"] = self.int_dtype
elif key2 == "max_cache_fd":
kwargs["max_cache_fd"] = self.max_cache_fd
else:
raise RuntimeError(f"Not implemented keyword argument: {key2}")
func = dic["func"]
try:
return func(path, **kwargs)
except Exception:
if hasattr(func, "__name__"):
name = func.__name__
else:
name = str(func)
logging.error(f"An error happened with {name}({path})")
raise
else:
raise RuntimeError(f"Not supported: loader_type={loader_type}")
def has_name(self, name) -> bool:
return name in self.loader_dict
def names(self) -> Tuple[str, ...]:
return tuple(self.loader_dict)
def __iter__(self):
return iter(next(iter(self.loader_dict.values())))
def __repr__(self):
_mes = self.__class__.__name__
_mes += "("
for name, (path, _type) in self.debug_info.items():
_mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
_mes += f"\n preprocess: {self.preprocess})"
return _mes
def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
assert check_argument_types()
# Change integer-id to string-id
if isinstance(uid, int):
d = next(iter(self.loader_dict.values()))
uid = list(d)[uid]
if self.cache is not None and uid in self.cache:
data = self.cache[uid]
return uid, data
data = {}
# 1. Load data from each loaders
for name, loader in self.loader_dict.items():
try:
value = loader[uid]
if isinstance(value, (list, tuple)):
value = np.array(value)
if not isinstance(
value, (np.ndarray, torch.Tensor, str, numbers.Number)
):
raise TypeError(
f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
)
except Exception:
path, _type = self.debug_info[name]
logging.error(
f"Error happened with path={path}, type={_type}, id={uid}"
)
raise
# torch.Tensor is converted to ndarray
if isinstance(value, torch.Tensor):
value = value.numpy()
elif isinstance(value, numbers.Number):
value = np.array([value])
data[name] = value
# 2. [Option] Apply preprocessing
# e.g. funasr_local.train.preprocessor:CommonPreprocessor
if self.preprocess is not None:
data = self.preprocess(uid, data)
# 3. Force data-precision
for name in data:
value = data[name]
if not isinstance(value, np.ndarray):
raise RuntimeError(
f"All values must be converted to np.ndarray object "
f'by preprocessing, but "{name}" is still {type(value)}.'
)
# Cast to desired type
if value.dtype.kind == "f":
value = value.astype(self.float_dtype)
elif value.dtype.kind == "i":
value = value.astype(self.int_dtype)
else:
raise NotImplementedError(f"Not supported dtype: {value.dtype}")
data[name] = value
if self.cache is not None and self.cache.size < self.max_cache_size:
self.cache[uid] = data
retval = uid, data
assert check_return_type(retval)
return retval

View File

@@ -0,0 +1,388 @@
"""Iterable dataset module."""
import copy
from io import StringIO
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import Iterator
from typing import Tuple
from typing import Union
from typing import List
# import kaldiio
import numpy as np
import torch
import torchaudio
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path
from funasr_local.datasets.dataset import ESPnetDataset
SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
def load_kaldi(input):
    # kaldiio is imported lazily here because the module-level import above is commented out
    import kaldiio
    retval = kaldiio.load_mat(input)
if isinstance(retval, tuple):
assert len(retval) == 2, len(retval)
if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
# sound scp case
rate, array = retval
elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
# Extended ark format case
array, rate = retval
else:
raise RuntimeError(f"Unexpected type: {type(retval[0])}, {type(retval[1])}")
        # Multichannel wave file
# array: (NSample, Channel) or (Nsample)
else:
# Normal ark case
assert isinstance(retval, np.ndarray), type(retval)
array = retval
return array
def load_bytes(input):
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
if middle_data.dtype.kind not in 'iu':
raise TypeError("'middle_data' must be an array of integers")
dtype = np.dtype('float32')
if dtype.kind != 'f':
raise TypeError("'dtype' must be a floating point type")
i = np.iinfo(middle_data.dtype)
abs_max = 2 ** (i.bits - 1)
offset = i.min + abs_max
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
def load_pcm(input):
with open(input,"rb") as f:
bytes = f.read()
return load_bytes(bytes)
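# The int16 -> float32 conversion in load_bytes maps the full integer range onto [-1, 1):
# for int16, abs_max = 2**15 = 32768 and offset = -32768 + 32768 = 0, so a sample value of
# 16384 becomes 16384 / 32768 = 0.5 (the offset term only matters for unsigned dtypes).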
DATA_TYPES = {
"sound": lambda x: torchaudio.load(x)[0].numpy(),
"pcm": load_pcm,
"kaldi_ark": load_kaldi,
"bytes": load_bytes,
"waveform": lambda x: x,
"npy": np.load,
"text_int": lambda x: np.loadtxt(
StringIO(x), ndmin=1, dtype=np.long, delimiter=" "
),
"csv_int": lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=","),
"text_float": lambda x: np.loadtxt(
StringIO(x), ndmin=1, dtype=np.float32, delimiter=" "
),
"csv_float": lambda x: np.loadtxt(
StringIO(x), ndmin=1, dtype=np.float32, delimiter=","
),
"text": lambda x: x,
}
class IterableESPnetDataset(IterableDataset):
"""Pytorch Dataset class for ESPNet.
Examples:
>>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'),
... ('token_int', 'output', 'text_int')],
... )
>>> for uid, data in dataset:
... data
{'input': per_utt_array, 'output': per_utt_array}
"""
def __init__(
self,
path_name_type_list: Collection[Tuple[any, str, str]],
preprocess: Callable[
[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
] = None,
float_dtype: str = "float32",
fs: dict = None,
mc: bool = False,
int_dtype: str = "long",
key_file: str = None,
):
assert check_argument_types()
if len(path_name_type_list) == 0:
raise ValueError(
'1 or more elements are required for "path_name_type_list"'
)
path_name_type_list = copy.deepcopy(path_name_type_list)
self.preprocess = preprocess
self.float_dtype = float_dtype
self.int_dtype = int_dtype
self.key_file = key_file
self.fs = fs
self.mc = mc
self.debug_info = {}
non_iterable_list = []
self.path_name_type_list = []
if not isinstance(path_name_type_list[0], (Tuple, List)):
path = path_name_type_list[0]
name = path_name_type_list[1]
_type = path_name_type_list[2]
self.debug_info[name] = path, _type
if _type not in DATA_TYPES:
non_iterable_list.append((path, name, _type))
else:
self.path_name_type_list.append((path, name, _type))
else:
for path, name, _type in path_name_type_list:
self.debug_info[name] = path, _type
if _type not in DATA_TYPES:
non_iterable_list.append((path, name, _type))
else:
self.path_name_type_list.append((path, name, _type))
if len(non_iterable_list) != 0:
# Some types don't support iterable mode
self.non_iterable_dataset = ESPnetDataset(
path_name_type_list=non_iterable_list,
preprocess=preprocess,
float_dtype=float_dtype,
int_dtype=int_dtype,
)
else:
self.non_iterable_dataset = None
self.apply_utt2category = False
def has_name(self, name) -> bool:
return name in self.debug_info
def names(self) -> Tuple[str, ...]:
return tuple(self.debug_info)
def __repr__(self):
_mes = self.__class__.__name__
_mes += "("
for name, (path, _type) in self.debug_info.items():
_mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
_mes += f"\n preprocess: {self.preprocess})"
return _mes
def __iter__(self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
count = 0
if len(self.path_name_type_list) != 0 and (self.path_name_type_list[0][2] == "bytes" or self.path_name_type_list[0][2] == "waveform"):
linenum = len(self.path_name_type_list)
data = {}
for i in range(linenum):
value = self.path_name_type_list[i][0]
uid = 'utt_id'
name = self.path_name_type_list[i][1]
_type = self.path_name_type_list[i][2]
func = DATA_TYPES[_type]
array = func(value)
if self.fs is not None and (name == "speech" or name == "ref_speech"):
audio_fs = self.fs["audio_fs"]
model_fs = self.fs["model_fs"]
if audio_fs is not None and model_fs is not None:
array = torch.from_numpy(array)
array = array.unsqueeze(0)
array = torchaudio.transforms.Resample(orig_freq=audio_fs,
new_freq=model_fs)(array)
array = array.squeeze(0).numpy()
data[name] = array
if self.preprocess is not None:
data = self.preprocess(uid, data)
for name in data:
count += 1
value = data[name]
if not isinstance(value, np.ndarray):
raise RuntimeError(
f'All values must be converted to np.ndarray object '
f'by preprocessing, but "{name}" is still {type(value)}.')
# Cast to desired type
if value.dtype.kind == 'f':
value = value.astype(self.float_dtype)
elif value.dtype.kind == 'i':
value = value.astype(self.int_dtype)
else:
raise NotImplementedError(
f'Not supported dtype: {value.dtype}')
data[name] = value
yield uid, data
elif len(self.path_name_type_list) != 0 and self.path_name_type_list[0][2] == "sound" and not self.path_name_type_list[0][0].lower().endswith(".scp"):
linenum = len(self.path_name_type_list)
data = {}
for i in range(linenum):
value = self.path_name_type_list[i][0]
uid = os.path.basename(self.path_name_type_list[i][0]).split(".")[0]
name = self.path_name_type_list[i][1]
_type = self.path_name_type_list[i][2]
if _type == "sound":
audio_type = os.path.basename(value).lower()
if audio_type.rfind(".pcm") >= 0:
_type = "pcm"
func = DATA_TYPES[_type]
array = func(value)
if self.fs is not None and (name == "speech" or name == "ref_speech"):
audio_fs = self.fs["audio_fs"]
model_fs = self.fs["model_fs"]
if audio_fs is not None and model_fs is not None:
array = torch.from_numpy(array)
array = torchaudio.transforms.Resample(orig_freq=audio_fs,
new_freq=model_fs)(array)
array = array.numpy()
if _type == "sound":
if self.mc:
data[name] = array.transpose((1, 0))
else:
data[name] = array[0]
else:
data[name] = array
if self.preprocess is not None:
data = self.preprocess(uid, data)
for name in data:
count += 1
value = data[name]
if not isinstance(value, np.ndarray):
raise RuntimeError(
f'All values must be converted to np.ndarray object '
f'by preprocessing, but "{name}" is still {type(value)}.')
# Cast to desired type
if value.dtype.kind == 'f':
value = value.astype(self.float_dtype)
elif value.dtype.kind == 'i':
value = value.astype(self.int_dtype)
else:
raise NotImplementedError(
f'Not supported dtype: {value.dtype}')
data[name] = value
yield uid, data
else:
if self.key_file is not None:
uid_iter = (
line.rstrip().split(maxsplit=1)[0]
for line in open(self.key_file, encoding="utf-8")
)
elif len(self.path_name_type_list) != 0:
uid_iter = (
line.rstrip().split(maxsplit=1)[0]
for line in open(self.path_name_type_list[0][0], encoding="utf-8")
)
else:
uid_iter = iter(self.non_iterable_dataset)
files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list]
worker_info = torch.utils.data.get_worker_info()
linenum = 0
for count, uid in enumerate(uid_iter, 1):
# If num_workers>=1, split keys
if worker_info is not None:
if (count - 1) % worker_info.num_workers != worker_info.id:
continue
# 1. Read a line from each file
while True:
keys = []
values = []
for f in files:
linenum += 1
try:
line = next(f)
except StopIteration:
raise RuntimeError(f"{uid} is not found in the files")
sps = line.rstrip().split(maxsplit=1)
if len(sps) != 2:
raise RuntimeError(
f"This line doesn't include a space:"
f" {f}:L{linenum}: {line})"
)
key, value = sps
keys.append(key)
values.append(value)
for k_idx, k in enumerate(keys):
if k != keys[0]:
raise RuntimeError(
f"Keys are mismatched. Text files (idx={k_idx}) is "
f"not sorted or not having same keys at L{linenum}"
)
# If the key is matched, break the loop
if len(keys) == 0 or keys[0] == uid:
break
# 2. Load the entry from each line and create a dict
data = {}
# 2.a. Load data streamingly
for value, (path, name, _type) in zip(values, self.path_name_type_list):
if _type == "sound":
audio_type = os.path.basename(value).lower()
if audio_type.rfind(".pcm") >= 0:
_type = "pcm"
func = DATA_TYPES[_type]
# Load entry
array = func(value)
if self.fs is not None and name == "speech":
audio_fs = self.fs["audio_fs"]
model_fs = self.fs["model_fs"]
if audio_fs is not None and model_fs is not None:
array = torch.from_numpy(array)
array = torchaudio.transforms.Resample(orig_freq=audio_fs,
new_freq=model_fs)(array)
array = array.numpy()
if _type == "sound":
if self.mc:
data[name] = array.transpose((1, 0))
else:
data[name] = array[0]
else:
data[name] = array
if self.non_iterable_dataset is not None:
# 2.b. Load data from non-iterable dataset
_, from_non_iterable = self.non_iterable_dataset[uid]
data.update(from_non_iterable)
# 3. [Option] Apply preprocessing
# e.g. funasr_local.train.preprocessor:CommonPreprocessor
if self.preprocess is not None:
data = self.preprocess(uid, data)
# 4. Force data-precision
for name in data:
value = data[name]
if not isinstance(value, np.ndarray):
raise RuntimeError(
f"All values must be converted to np.ndarray object "
f'by preprocessing, but "{name}" is still {type(value)}.'
)
# Cast to desired type
if value.dtype.kind == "f":
value = value.astype(self.float_dtype)
elif value.dtype.kind == "i":
value = value.astype(self.int_dtype)
else:
raise NotImplementedError(f"Not supported dtype: {value.dtype}")
data[name] = value
yield uid, data
if count == 0:
raise RuntimeError("No iteration")

View File

@@ -0,0 +1,349 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
"""Iterable dataset module."""
import copy
from io import StringIO
from pathlib import Path
from typing import Callable, Collection, Dict, Iterator, Tuple, Union
import kaldiio
import numpy as np
import soundfile
import torch
from funasr_local.datasets.dataset import ESPnetDataset
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
from funasr_local.utils import wav_utils
def load_kaldi(input):
retval = kaldiio.load_mat(input)
if isinstance(retval, tuple):
assert len(retval) == 2, len(retval)
if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
# sound scp case
rate, array = retval
elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
# Extended ark format case
array, rate = retval
else:
raise RuntimeError(
f'Unexpected type: {type(retval[0])}, {type(retval[1])}')
# Multichannel wave file
# array: (NSample, Channel) or (Nsample)
else:
# Normal ark case
assert isinstance(retval, np.ndarray), type(retval)
array = retval
return array
DATA_TYPES = {
'sound':
lambda x: soundfile.read(x)[0],
'kaldi_ark':
load_kaldi,
'npy':
np.load,
'text_int':
lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.int64, delimiter=' '),
'csv_int':
lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.int64, delimiter=','),
'text_float':
lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=' '
),
'csv_float':
lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=','
),
'text':
lambda x: x,
}
class IterableESPnetDatasetModelScope(IterableDataset):
"""Pytorch Dataset class for ESPNet.
Examples:
>>> dataset = IterableESPnetDatasetModelScope([('wav.scp', 'input', 'sound'),
... ('token_int', 'output', 'text_int')],
... )
>>> for uid, data in dataset:
... data
{'input': per_utt_array, 'output': per_utt_array}
"""
def __init__(self,
path_name_type_list: Collection[Tuple[any, str, str]],
preprocess: Callable[[str, Dict[str, np.ndarray]],
Dict[str, np.ndarray]] = None,
float_dtype: str = 'float32',
int_dtype: str = 'long',
key_file: str = None,
sample_rate: Union[dict, int] = 16000):
assert check_argument_types()
if len(path_name_type_list) == 0:
raise ValueError(
'1 or more elements are required for "path_name_type_list"')
self.preprocess = preprocess
self.float_dtype = float_dtype
self.int_dtype = int_dtype
self.key_file = key_file
self.sample_rate = sample_rate
self.debug_info = {}
non_iterable_list = []
self.path_name_type_list = []
path_list = path_name_type_list[0]
name = path_name_type_list[1]
_type = path_name_type_list[2]
if name in self.debug_info:
raise RuntimeError(f'"{name}" is duplicated for data-key')
self.debug_info[name] = path_list, _type
# for path, name, _type in path_name_type_list:
for path in path_list:
self.path_name_type_list.append((path, name, _type))
if len(non_iterable_list) != 0:
# Some types don't support iterable mode
self.non_iterable_dataset = ESPnetDataset(
path_name_type_list=non_iterable_list,
preprocess=preprocess,
float_dtype=float_dtype,
int_dtype=int_dtype,
)
else:
self.non_iterable_dataset = None
self.apply_utt2category = False
def has_name(self, name) -> bool:
return name in self.debug_info
def names(self) -> Tuple[str, ...]:
return tuple(self.debug_info)
def __repr__(self):
_mes = self.__class__.__name__
_mes += '('
for name, (path, _type) in self.debug_info.items():
_mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
_mes += f'\n preprocess: {self.preprocess})'
return _mes
def __iter__(
self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
torch.set_printoptions(profile='default')
count = len(self.path_name_type_list)
for idx in range(count):
# 2. Load the entry from each line and create a dict
data = {}
# 2.a. Load data streamingly
# value: /home/fsc/code/MaaS/MaaS-lib-nls-asr/data/test/audios/asr_example.wav
value = self.path_name_type_list[idx][0]['file']
uid = self.path_name_type_list[idx][0]['key']
# name: speech
name = self.path_name_type_list[idx][1]
_type = self.path_name_type_list[idx][2]
func = DATA_TYPES[_type]
array = func(value)
# 2.b. audio resample
if _type == 'sound':
audio_sr: int = 16000
model_sr: int = 16000
if isinstance(self.sample_rate, int):
model_sr = self.sample_rate
else:
if 'audio_sr' in self.sample_rate:
audio_sr = self.sample_rate['audio_sr']
if 'model_sr' in self.sample_rate:
model_sr = self.sample_rate['model_sr']
array = wav_utils.torch_resample(array, audio_sr, model_sr)
# array: [ 1.25122070e-03 ... ]
data[name] = array
# 3. [Option] Apply preprocessing
# e.g. espnet2.train.preprocessor:CommonPreprocessor
if self.preprocess is not None:
data = self.preprocess(uid, data)
# data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
# 4. Force data-precision
for name in data:
# value is np.ndarray data
value = data[name]
if not isinstance(value, np.ndarray):
raise RuntimeError(
f'All values must be converted to np.ndarray object '
f'by preprocessing, but "{name}" is still {type(value)}.'
)
# Cast to desired type
if value.dtype.kind == 'f':
value = value.astype(self.float_dtype)
elif value.dtype.kind == 'i':
value = value.astype(self.int_dtype)
else:
raise NotImplementedError(
f'Not supported dtype: {value.dtype}')
data[name] = value
yield uid, data
if count == 0:
raise RuntimeError('No iteration')
class IterableESPnetBytesModelScope(IterableDataset):
"""Pytorch audio bytes class for ESPNet.
Examples:
>>> dataset = IterableESPnetBytesModelScope([('audio bytes', 'input', 'sound'),
... ('token_int', 'output', 'text_int')],
... )
>>> for uid, data in dataset:
... data
{'input': per_utt_array, 'output': per_utt_array}
"""
def __init__(self,
path_name_type_list: Collection[Tuple[any, str, str]],
preprocess: Callable[[str, Dict[str, np.ndarray]],
Dict[str, np.ndarray]] = None,
float_dtype: str = 'float32',
int_dtype: str = 'long',
key_file: str = None,
sample_rate: Union[dict, int] = 16000):
assert check_argument_types()
if len(path_name_type_list) == 0:
raise ValueError(
'1 or more elements are required for "path_name_type_list"')
self.preprocess = preprocess
self.float_dtype = float_dtype
self.int_dtype = int_dtype
self.key_file = key_file
self.sample_rate = sample_rate
self.debug_info = {}
non_iterable_list = []
self.path_name_type_list = []
audio_data = path_name_type_list[0]
name = path_name_type_list[1]
_type = path_name_type_list[2]
if name in self.debug_info:
raise RuntimeError(f'"{name}" is duplicated for data-key')
self.debug_info[name] = audio_data, _type
self.path_name_type_list.append((audio_data, name, _type))
if len(non_iterable_list) != 0:
# Some types don't support iterable mode
self.non_iterable_dataset = ESPnetDataset(
path_name_type_list=non_iterable_list,
preprocess=preprocess,
float_dtype=float_dtype,
int_dtype=int_dtype,
)
else:
self.non_iterable_dataset = None
self.apply_utt2category = False
if float_dtype == 'float32':
self.np_dtype = np.float32
def has_name(self, name) -> bool:
return name in self.debug_info
def names(self) -> Tuple[str, ...]:
return tuple(self.debug_info)
def __repr__(self):
_mes = self.__class__.__name__
_mes += '('
for name, (path, _type) in self.debug_info.items():
_mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
_mes += f'\n preprocess: {self.preprocess})'
return _mes
def __iter__(
self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
torch.set_printoptions(profile='default')
# 2. Load the entry from each line and create a dict
data = {}
# 2.a. Load data streamingly
value = self.path_name_type_list[0][0]
uid = 'pcm_data'
# name: speech
name = self.path_name_type_list[0][1]
_type = self.path_name_type_list[0][2]
func = DATA_TYPES[_type]
# array: [ 1.25122070e-03 ... ]
# data[name] = np.frombuffer(value, dtype=self.np_dtype)
# 2.b. byte(PCM16) to float32
middle_data = np.frombuffer(value, dtype=np.int16)
middle_data = np.asarray(middle_data)
if middle_data.dtype.kind not in 'iu':
raise TypeError("'middle_data' must be an array of integers")
dtype = np.dtype('float32')
if dtype.kind != 'f':
raise TypeError("'dtype' must be a floating point type")
i = np.iinfo(middle_data.dtype)
abs_max = 2**(i.bits - 1)
offset = i.min + abs_max
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max,
dtype=self.np_dtype)
# 2.c. audio resample
if _type == 'sound':
audio_sr: int = 16000
model_sr: int = 16000
if isinstance(self.sample_rate, int):
model_sr = self.sample_rate
else:
if 'audio_sr' in self.sample_rate:
audio_sr = self.sample_rate['audio_sr']
if 'model_sr' in self.sample_rate:
model_sr = self.sample_rate['model_sr']
array = wav_utils.torch_resample(array, audio_sr, model_sr)
data[name] = array
# 3. [Option] Apply preprocessing
# e.g. espnet2.train.preprocessor:CommonPreprocessor
if self.preprocess is not None:
data = self.preprocess(uid, data)
# data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
# 4. Force data-precision
for name in data:
# value is np.ndarray data
value = data[name]
if not isinstance(value, np.ndarray):
raise RuntimeError(
f'All values must be converted to np.ndarray object '
f'by preprocessing, but "{name}" is still {type(value)}.')
# Cast to desired type
if value.dtype.kind == 'f':
value = value.astype(self.float_dtype)
elif value.dtype.kind == 'i':
value = value.astype(self.int_dtype)
else:
raise NotImplementedError(
f'Not supported dtype: {value.dtype}')
data[name] = value
yield uid, data

View File

@@ -0,0 +1,96 @@
import logging
from pathlib import Path
from typing import Iterable
from typing import List
from typing import Union
import sentencepiece as spm
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from funasr_local.datasets.large_datasets.dataset import Dataset
from funasr_local.iterators.abs_iter_factory import AbsIterFactory
from funasr_local.text.abs_tokenizer import AbsTokenizer
def read_symbol_table(symbol_table_file):
if isinstance(symbol_table_file, str):
symbol_table = {}
with open(symbol_table_file, "r", encoding="utf8") as fin:
for i, line in enumerate(fin):
char = line.strip()
symbol_table[char] = i
else:
assert isinstance(symbol_table_file, list)
symbol_table = {}
for i, char in enumerate(symbol_table_file):
symbol_table[char] = i
return symbol_table
def load_seg_dict(seg_dict_file):
seg_dict = {}
assert isinstance(seg_dict_file, str)
with open(seg_dict_file, "r", encoding="utf8") as f:
lines = f.readlines()
for line in lines:
s = line.strip().split()
key = s[0]
value = s[1:]
seg_dict[key] = " ".join(value)
return seg_dict
class SentencepiecesTokenizer(AbsTokenizer):
def __init__(self, model: Union[Path, str]):
assert check_argument_types()
self.model = str(model)
self.sp = None
def __repr__(self):
return f'{self.__class__.__name__}(model="{self.model}")'
def _build_sentence_piece_processor(self):
if self.sp is None:
self.sp = spm.SentencePieceProcessor()
self.sp.load(self.model)
def text2tokens(self, line: str) -> List[str]:
self._build_sentence_piece_processor()
return self.sp.EncodeAsPieces(line)
def tokens2text(self, tokens: Iterable[str]) -> str:
self._build_sentence_piece_processor()
return self.sp.DecodePieces(list(tokens))
class ArkDataLoader(AbsIterFactory):
def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None,
bpemodel_file=None, mode="train"):
symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
if seg_dict_file is not None:
seg_dict = load_seg_dict(seg_dict_file)
else:
seg_dict = None
if punc_dict_file is not None:
punc_dict = read_symbol_table(punc_dict_file)
else:
punc_dict = None
self.dataset_conf = dataset_conf
self.frontend_conf = frontend_conf
logging.info("dataloader config: {}".format(self.dataset_conf))
batch_mode = self.dataset_conf.get("batch_mode", "padding")
if bpemodel_file is not None:
bpe_tokenizer = SentencepiecesTokenizer(bpemodel_file)
else:
bpe_tokenizer = None
self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
self.dataset_conf, self.frontend_conf, mode=mode, batch_mode=batch_mode)
def build_iter(self, epoch, shuffle=True):
self.dataset.set_epoch(epoch)
data_loader = DataLoader(self.dataset,
batch_size=None,
pin_memory=True,
num_workers=self.dataset_conf.get("num_workers", 8))
return data_loader

View File

@@ -0,0 +1,213 @@
import random
from itertools import count
from functools import partial
from torch.utils.data import IterableDataset
from funasr_local.datasets.large_datasets.datapipes.map import MapperIterDataPipe
tiebreaker = count()
def _default_len_fn(token):
return len(token), next(tiebreaker)
def _token_len_fn(token, len_fn):
return len_fn(token), next(tiebreaker), token
class MaxTokenBucketizerIterDataPipe(IterableDataset):
def __init__(
self,
datapipe,
batch_size=8000,
len_fn=_default_len_fn,
buffer_size=10240,
sort_size=500,
batch_mode="padding",
):
assert batch_size > 0, "Batch size is required to be larger than 0!"
assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
assert sort_size > 0, "Sort size is required to be larger than 0!"
datapipe = MapperIterDataPipe(datapipe, fn=partial(_token_len_fn, len_fn=len_fn))
self.datapipe = datapipe
self.batch_size = batch_size
self.buffer_size = buffer_size
self.sort_size = sort_size
self.batch_mode = batch_mode
def set_epoch(self, epoch):
self.epoch = epoch
def __iter__(self):
buffer = []
batch = []
bucket = []
max_lengths = 0
min_lengths = 999999
batch_lengths = 0
if self.batch_mode == "clipping":
assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1"
for d in self.datapipe:
if d[0] > self.batch_size:
continue
buffer.append(d)
if len(buffer) == self.buffer_size:
random.shuffle(buffer)
for sample in buffer:
bucket.append(sample)
if len(bucket) == self.sort_size:
bucket.sort()
for x in bucket:
length, _, token = x
if length < min_lengths:
min_lengths = length
batch_lengths = min_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
yield batch
batch = []
min_lengths = length
batch.append(token)
bucket = []
buffer = []
if buffer:
random.shuffle(buffer)
for sample in buffer:
bucket.append(sample)
if len(bucket) == self.sort_size:
bucket.sort()
for x in bucket:
length, _, token = x
if length < min_lengths:
min_lengths = length
batch_lengths = min_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
yield batch
batch = []
min_lengths = length
batch.append(token)
bucket = []
buffer = []
if bucket:
bucket.sort()
for x in bucket:
length, _, token = x
if length < min_lengths:
min_lengths = length
batch_lengths = min_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
yield batch
batch = []
min_lengths = length
batch.append(token)
bucket = []
if batch:
yield batch
else:
if self.buffer_size == -1:
for d in self.datapipe:
if d[0] > self.batch_size:
continue
buffer.append(d)
buffer.sort()
for sample in buffer:
length, _, token = sample
if length > max_lengths:
max_lengths = length
batch_lengths = max_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
bucket.append(batch)
batch = []
max_lengths = length
batch.append(token)
random.shuffle(bucket)
if bucket:
for batch_sample in bucket:
yield batch_sample
if batch:
yield batch
elif self.buffer_size == 0:
for d in self.datapipe:
if d[0] > self.batch_size:
continue
length, _, token = d
if length > self.batch_size:
continue
if length > max_lengths:
max_lengths = length
batch_lengths = max_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
yield batch
batch = []
max_lengths = length
batch.append(token)
if batch:
yield batch
else:
for d in self.datapipe:
if d[0] > self.batch_size:
continue
buffer.append(d)
if len(buffer) == self.buffer_size:
random.shuffle(buffer)
for sample in buffer:
bucket.append(sample)
if len(bucket) == self.sort_size:
bucket.sort()
for x in bucket:
length, _, token = x
if length > max_lengths:
max_lengths = length
batch_lengths = max_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
yield batch
batch = []
max_lengths = length
batch.append(token)
bucket = []
buffer = []
if buffer:
random.shuffle(buffer)
for sample in buffer:
bucket.append(sample)
if len(bucket) == self.sort_size:
bucket.sort()
for x in bucket:
length, _, token = x
if length > max_lengths:
max_lengths = length
batch_lengths = max_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
yield batch
batch = []
max_lengths = length
batch.append(token)
bucket = []
buffer = []
if bucket:
bucket.sort()
for x in bucket:
length, _, token = x
if length > max_lengths:
max_lengths = length
batch_lengths = max_lengths * (len(batch) + 1)
if batch_lengths > self.batch_size:
yield batch
batch = []
max_lengths = length
batch.append(token)
bucket = []
if batch:
yield batch

View File

@@ -0,0 +1,24 @@
from torch.utils.data import IterableDataset
def default_fn(data):
return data
class FilterIterDataPipe(IterableDataset):
def __init__(self,
datapipe,
fn=default_fn):
self.datapipe = datapipe
self.fn = fn
def set_epoch(self, epoch):
self.epoch = epoch
def __iter__(self):
assert callable(self.fn)
for data in self.datapipe:
if self.fn(data):
yield data
else:
continue

View File

@@ -0,0 +1,22 @@
from torch.utils.data import IterableDataset
def default_fn(data):
return data
class MapperIterDataPipe(IterableDataset):
def __init__(self,
datapipe,
fn=default_fn):
self.datapipe = datapipe
self.fn = fn
def set_epoch(self, epoch):
self.epoch = epoch
def __iter__(self):
assert callable(self.fn)
for data in self.datapipe:
yield self.fn(data)

View File

@@ -0,0 +1,212 @@
import os
import random
import numpy
from functools import partial
import torch
import torchaudio
import torch.distributed as dist
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
from funasr_local.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
from funasr_local.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
from funasr_local.datasets.large_datasets.datapipes.map import MapperIterDataPipe
from funasr_local.datasets.large_datasets.utils.filter import filter
from funasr_local.datasets.large_datasets.utils.padding import padding
from funasr_local.datasets.large_datasets.utils.clipping import clipping
from funasr_local.datasets.large_datasets.utils.tokenize import tokenize
def read_lists(list_file):
lists = []
with open(list_file, 'r', encoding='utf8') as fin:
for line in fin:
parts = line.strip()
lists.append(parts)
return lists
class AudioDataset(IterableDataset):
def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train"):
self.scp_lists = scp_lists
self.data_names = data_names
self.data_types = data_types
self.frontend_conf = frontend_conf
self.shuffle = shuffle
self.mode = mode
self.epoch = -1
self.rank = 0
self.world_size = 1
self.worker_id = 0
self.num_workers = 1
def set_epoch(self, epoch):
self.epoch = epoch
def get_rank_data_list(self, data_index):
assert dist.is_available()
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
else:
self.rank = 0
self.world_size = 1
if self.mode == "train":
if self.shuffle:
random.seed(self.epoch)
random.shuffle(data_index)
return data_index[self.rank::self.world_size]
return data_index
def get_worker_data_list(self, rank_data_index):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.worker_id = 0
self.num_workers = 1
else:
self.worker_id = worker_info.id
self.num_workers = worker_info.num_workers
return rank_data_index[self.worker_id::self.num_workers]
def close_reader(self, reader_list):
for reader in reader_list:
reader.close()
def __iter__(self):
data_index = list(range(len(self.scp_lists)))
rank_data_index = self.get_rank_data_list(data_index)
worker_data_index = self.get_worker_data_list(rank_data_index)
for index in worker_data_index:
data = dict(scp=self.scp_lists[index])
assert 'scp' in data
scp = data['scp']
data_file_list = scp.strip().split()
data_name_list = self.data_names.split(",")
data_type_list = self.data_types.split(",")
for file in data_file_list:
assert os.path.exists(file), "{} not exists".format(file)
assert len(data_file_list) == len(data_name_list) == len(data_type_list), \
"The item number of data, data_names, data_types must be the same "
reader_list = []
for data_file, data_type in zip(data_file_list, data_type_list):
if data_type == "kaldi_ark":
ark_reader = ReadHelper('ark:{}'.format(data_file))
reader_list.append(ark_reader)
elif data_type == "text" or data_type == "sound":
text_reader = open(data_file, "r")
reader_list.append(text_reader)
elif data_type == "none":
continue
else:
raise TypeError("Data type {} is not supported".format(data_type))
for items in zip(*reader_list):
sample_dict = {}
for item, (data_name, data_type) in zip(items, zip(data_name_list, data_type_list)):
if data_type == "kaldi_ark":
key, mat = item
sample_dict[data_name] = mat
if data_name == "speech":
sample_dict["key"] = key
elif data_type == "sound":
key, path = item.strip().split()
waveform, sampling_rate = torchaudio.load(path)
if self.frontend_conf is not None:
if sampling_rate != self.frontend_conf["fs"]:
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
new_freq=self.frontend_conf["fs"])(waveform)
sampling_rate = self.frontend_conf["fs"]
waveform = waveform.numpy()
mat = waveform[0]
sample_dict[data_name] = mat
sample_dict["sampling_rate"] = sampling_rate
if data_name == "speech":
sample_dict["key"] = key
else:
text = item
segs = text.strip().split()
sample_dict[data_name] = segs[1:]
if "key" not in sample_dict:
sample_dict["key"] = segs[0]
yield sample_dict
self.close_reader(reader_list)
def len_fn_example(data):
return 1
def len_fn_token(data):
assert "speech" in data
if "sampling_rate" in data:
return (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
else:
return data["speech"].shape[0]
def Dataset(data_list_file,
dict,
seg_dict,
punc_dict,
bpe_tokenizer,
conf,
frontend_conf,
mode="train",
batch_mode="padding"):
scp_lists = read_lists(data_list_file)
shuffle = conf.get('shuffle', True)
data_names = conf.get("data_names", "speech,text")
data_types = conf.get("data_types", "kaldi_ark,text")
dataset = AudioDataset(scp_lists, data_names, data_types, frontend_conf=frontend_conf, shuffle=shuffle, mode=mode)
filter_conf = conf.get('filter_conf', {})
filter_fn = partial(filter, **filter_conf)
dataset = FilterIterDataPipe(dataset, fn=filter_fn)
if "text" in data_names:
vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer}
tokenize_fn = partial(tokenize, **vocab)
dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
if shuffle:
buffer_conf = conf.get('shuffle_conf', {})
buffer_size = buffer_conf['shuffle_size']
sort_size = buffer_conf['sort_size']
else:
buffer_size = 0
sort_size = 1
batch_conf = conf.get('batch_conf', {})
batch_size = batch_conf['batch_size']
batch_type = batch_conf['batch_type']
assert batch_type in ["example", "token"]
if batch_type == 'example':
len_fn = len_fn_example
else:
len_fn = len_fn_token
dataset = MaxTokenBucketizerIterDataPipe(dataset,
batch_size=batch_size,
len_fn=len_fn,
buffer_size=buffer_size,
sort_size=sort_size,
batch_mode=batch_mode)
int_pad_value = conf.get("int_pad_value", -1)
float_pad_value = conf.get("float_pad_value", 0.0)
padding_conf = {"int_pad_value": int_pad_value, "float_pad_value": float_pad_value}
padding_fn = partial(padding, **padding_conf)
dataset = MapperIterDataPipe(dataset, fn=padding_fn if batch_mode == "padding" else clipping)
return dataset

View File

@@ -0,0 +1,40 @@
import numpy as np
import torch
from funasr_local.datasets.collate_fn import crop_to_max_size
def clipping(data):
assert isinstance(data, list)
assert "key" in data[0]
keys = [x["key"] for x in data]
batch = {}
data_names = data[0].keys()
for data_name in data_names:
if data_name == "key":
continue
else:
if data[0][data_name].dtype.kind == "i":
tensor_type = torch.int64
else:
tensor_type = torch.float32
tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
length_clip = min(tensor_lengths)
tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
diff = length - length_clip
assert diff >= 0
if diff == 0:
tensor_clip[i] = tensor
else:
tensor_clip[i] = crop_to_max_size(tensor, length_clip)
batch[data_name] = tensor_clip
batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long)
return keys, batch
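# Illustrative example (assumed shapes, not part of the original module):
#   data = [{"key": "utt1", "speech": np.zeros((6, 80), dtype=np.float32)},
#           {"key": "utt2", "speech": np.zeros((8, 80), dtype=np.float32)}]
#   keys, batch = clipping(data)
#   # keys == ["utt1", "utt2"]; every utterance is cropped to the shortest
#   # length in the batch, so batch["speech"].shape == (2, 6, 80) and
#   # batch["speech_lengths"].tolist() == [6, 6]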

View File

@@ -0,0 +1,26 @@
#!/usr/bin/env python
def filter(data,
speech_length_min=100,
speech_length_max=15000,
token_length_min=0,
token_length_max=200):
assert "speech" in data or "text" in data
if "speech" in data and "text" in data:
if "sampling_rate" in data:
speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
else:
speech_length = data["speech"].shape[0]
num_tokens = len(data['text'])
return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
elif "speech" in data:
if "sampling_rate" in data:
speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
else:
speech_length = data["speech"].shape[0]
return speech_length_min < speech_length < speech_length_max
else:
num_tokens = len(data['text'])
return token_length_min < num_tokens < token_length_max
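# Illustrative example (assumed defaults, not part of the original module):
# a sample with an 8000-frame "speech" matrix (or, when "sampling_rate" is
# present, roughly 8 seconds of audio == 8000 ms) and 30 "text" tokens passes
# the default thresholds, while an utterance shorter than 100 frames/ms or a
# transcript longer than 200 tokens is dropped.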

View File

@@ -0,0 +1,30 @@
import numpy as np
def build_LFR_features(data, m, n):
"""
Actually, this implements stacking frames and skipping frames.
if m = 1 and n = 1, just return the origin features.
if m = 1 and n > 1, it works like skipping.
if m > 1 and n = 1, it works like stacking but only supports right frames.
if m > 1 and n > 1, it works like LFR.
Args:
data: T x D np.ndarray of input features
m: number of frames to stack
n: number of frames to skip
"""
LFR_inputs = []
T = data.shape[0]
T_lfr = int(np.ceil(T / n))
for i in range(T_lfr):
if m <= T - i * n:
LFR_inputs.append(np.hstack(data[i*n:i*n+m]))
else:
num_padding = m - (T - i * n)
frame = np.hstack(data[i*n:])
for _ in range(num_padding):
frame = np.hstack((frame, data[-1]))
LFR_inputs.append(frame)
return np.vstack(LFR_inputs)
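# Minimal usage sketch (illustrative only): with m=4 stacked frames and n=3
# skipped frames, a (10, 80) feature matrix yields an output of shape
# (ceil(10 / 3), 4 * 80) == (4, 320); the final LFR frame is completed by
# repeating the last input frame.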

View File

@@ -0,0 +1,34 @@
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
def padding(data, float_pad_value=0.0, int_pad_value=-1):
assert isinstance(data, list)
assert "key" in data[0]
assert "speech" in data[0] or "text" in data[0]
keys = [x["key"] for x in data]
batch = {}
data_names = data[0].keys()
for data_name in data_names:
if data_name == "key" or data_name =="sampling_rate":
continue
else:
if data[0][data_name].dtype.kind == "i":
pad_value = int_pad_value
tensor_type = torch.int64
else:
pad_value = float_pad_value
tensor_type = torch.float32
tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
tensor_pad = pad_sequence(tensor_list,
batch_first=True,
padding_value=pad_value)
batch[data_name] = tensor_pad
batch[data_name + "_lengths"] = tensor_lengths
return keys, batch
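# Illustrative example (assumed values, not part of the original module):
#   data = [{"key": "utt1", "text": np.array([1, 2, 3])},
#           {"key": "utt2", "text": np.array([4, 5, 6, 7, 8])}]
#   keys, batch = padding(data, int_pad_value=-1)
#   # keys == ["utt1", "utt2"]; batch["text"].shape == (2, 5) with the short
#   # sequence padded by -1, and batch["text_lengths"].tolist() == [3, 5]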

View File

@@ -0,0 +1,81 @@
#!/usr/bin/env python
import re
import numpy as np
def forward_segment(text, seg_dict):
word_list = []
i = 0
while i < len(text):
longest_word = text[i]
for j in range(i + 1, len(text) + 1):
word = text[i:j]
if word in seg_dict:
if len(word) > len(longest_word):
longest_word = word
word_list.append(longest_word)
i += len(longest_word)
return word_list
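# Illustrative example (hypothetical dictionary, not part of the original
# module): with seg_dict containing "ab" and "abc", forward_segment("abcd",
# seg_dict) greedily matches the longest entry first and returns ["abc", "d"].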
def seg_tokenize(txt, seg_dict):
pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
out_txt = ""
for word in txt:
word = word.lower()
if word in seg_dict:
out_txt += seg_dict[word] + " "
else:
if pattern.match(word):
for char in word:
if char in seg_dict:
out_txt += seg_dict[char] + " "
else:
out_txt += "<unk>" + " "
else:
out_txt += "<unk>" + " "
return out_txt.strip().split()
def tokenize(data,
vocab=None,
seg_dict=None,
punc_dict=None,
bpe_tokenizer=None):
assert "text" in data
assert isinstance(vocab, dict)
text = data["text"]
token = []
vad = -2
if bpe_tokenizer is not None:
text = bpe_tokenizer.text2tokens("".join(text))
if seg_dict is not None:
assert isinstance(seg_dict, dict)
text = seg_tokenize(text, seg_dict)
length = len(text)
for i in range(length):
x = text[i]
if i == length-1 and "punc" in data and x.startswith("vad:"):
vad = x[4:]
if len(vad) == 0:
vad = -1
else:
vad = int(vad)
elif x in vocab:
token.append(vocab[x])
else:
token.append(vocab['<unk>'])
if "punc" in data and punc_dict is not None:
punc_token = []
for punc in data["punc"]:
if punc in punc_dict:
punc_token.append(punc_dict[punc])
else:
punc_token.append(punc_dict["_"])
data["punc"] = np.array(punc_token)
data["text"] = np.array(token)
if vad != -2:
data["vad_indexes"] = np.array([vad], dtype=np.int64)
return data

View File

@@ -0,0 +1,33 @@
import os
class MsDataset(object):
@classmethod
def load_core(cls, data_dir, data_set):
wav_file = os.path.join(data_dir, data_set, "wav.scp")
text_file = os.path.join(data_dir, data_set, "text")
with open(wav_file) as f:
wav_lines = f.readlines()
with open(text_file) as f:
text_lines = f.readlines()
data_list = []
for wav_line, text_line in zip(wav_lines, text_lines):
item = {}
item["Audio:FILE"] = wav_line.strip().split()[-1]
item["Text:LABEL"] = " ".join(text_line.strip().split()[1:])
data_list.append(item)
return data_list
@classmethod
def load(cls, dataset_name, namespace="speech_asr", train_set="train", dev_set="validation"):
if os.path.exists(dataset_name):
data_dir = dataset_name
ds_dict = {}
ds_dict["train"] = cls.load_core(data_dir, train_set)
ds_dict["validation"] = cls.load_core(data_dir, dev_set)
ds_dict["raw_data_dir"] = data_dir
return ds_dict
else:
from modelscope.msdatasets import MsDataset
ds_dict = MsDataset.load(dataset_name=dataset_name, namespace=namespace)
return ds_dict

View File

@@ -0,0 +1,824 @@
import re
from abc import ABC
from abc import abstractmethod
from pathlib import Path
from typing import Collection
from typing import Dict
from typing import Iterable
from typing import List
from typing import Union
import numpy as np
import scipy.signal
import soundfile
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.text.build_tokenizer import build_tokenizer
from funasr_local.text.cleaner import TextCleaner
from funasr_local.text.token_id_converter import TokenIDConverter
class AbsPreprocessor(ABC):
def __init__(self, train: bool):
self.train = train
@abstractmethod
def __call__(
self, uid: str, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
raise NotImplementedError
def forward_segment(text, dic):
word_list = []
i = 0
while i < len(text):
longest_word = text[i]
for j in range(i + 1, len(text) + 1):
word = text[i:j]
if word in dic:
if len(word) > len(longest_word):
longest_word = word
word_list.append(longest_word)
i += len(longest_word)
return word_list
def seg_tokenize(txt, seg_dict):
pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
out_txt = ""
for word in txt:
word = word.lower()
if word in seg_dict:
out_txt += seg_dict[word] + " "
else:
if pattern.match(word):
for char in word:
if char in seg_dict:
out_txt += seg_dict[char] + " "
else:
out_txt += "<unk>" + " "
else:
out_txt += "<unk>" + " "
return out_txt.strip().split()
def seg_tokenize_wo_pattern(txt, seg_dict):
out_txt = ""
for word in txt:
if word in seg_dict:
out_txt += seg_dict[word] + " "
else:
out_txt += "<unk>" + " "
return out_txt.strip().split()
def framing(
x,
frame_length: int = 512,
frame_shift: int = 256,
centered: bool = True,
padded: bool = True,
):
if x.size == 0:
raise ValueError("Input array size is zero")
if frame_length < 1:
raise ValueError("frame_length must be a positive integer")
if frame_length > x.shape[-1]:
raise ValueError("frame_length is greater than input length")
if 0 >= frame_shift:
raise ValueError("frame_shift must be greater than 0")
if centered:
pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [
(frame_length // 2, frame_length // 2)
]
x = np.pad(x, pad_shape, mode="constant", constant_values=0)
if padded:
# Pad to integer number of windowed segments
# I.e make x.shape[-1] = frame_length + (nseg-1)*nstep,
# with integer nseg
nadd = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length
pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [(0, nadd)]
x = np.pad(x, pad_shape, mode="constant", constant_values=0)
# Created strided array of data segments
if frame_length == 1 and frame_length == frame_shift:
result = x[..., None]
else:
shape = x.shape[:-1] + (
(x.shape[-1] - frame_length) // frame_shift + 1,
frame_length,
)
strides = x.strides[:-1] + (frame_shift * x.strides[-1], x.strides[-1])
result = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
return result
def detect_non_silence(
x: np.ndarray,
threshold: float = 0.01,
frame_length: int = 1024,
frame_shift: int = 512,
window: str = "boxcar",
) -> np.ndarray:
"""Power based voice activity detection.
Args:
x: (Channel, Time)
>>> x = np.random.randn(1000)
>>> detect = detect_non_silence(x)
>>> assert x.shape == detect.shape
>>> assert detect.dtype == bool
"""
if x.shape[-1] < frame_length:
return np.full(x.shape, fill_value=True, dtype=bool)
if x.dtype.kind == "i":
x = x.astype(np.float64)
# framed_w: (C, T, F)
framed_w = framing(
x,
frame_length=frame_length,
frame_shift=frame_shift,
centered=False,
padded=True,
)
framed_w *= scipy.signal.get_window(window, frame_length).astype(framed_w.dtype)
# power: (C, T)
power = (framed_w ** 2).mean(axis=-1)
# mean_power: (C, 1)
mean_power = np.mean(power, axis=-1, keepdims=True)
if np.all(mean_power == 0):
return np.full(x.shape, fill_value=True, dtype=bool)
# detect_frames: (C, T)
detect_frames = power / mean_power > threshold
# detects: (C, T, F)
detects = np.broadcast_to(
detect_frames[..., None], detect_frames.shape + (frame_shift,)
)
# detects: (C, TF)
detects = detects.reshape(*detect_frames.shape[:-1], -1)
# detects: (C, TF)
return np.pad(
detects,
[(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])],
mode="edge",
)
class CommonPreprocessor(AbsPreprocessor):
def __init__(
self,
train: bool,
token_type: str = None,
token_list: Union[Path, str, Iterable[str]] = None,
bpemodel: Union[Path, str, Iterable[str]] = None,
text_cleaner: Collection[str] = None,
g2p_type: str = None,
unk_symbol: str = "<unk>",
space_symbol: str = "<space>",
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
delimiter: str = None,
rir_scp: str = None,
rir_apply_prob: float = 1.0,
noise_scp: str = None,
noise_apply_prob: float = 1.0,
noise_db_range: str = "3_10",
speech_volume_normalize: float = None,
speech_name: str = "speech",
text_name: str = "text",
split_with_space: bool = False,
seg_dict_file: str = None,
):
super().__init__(train)
self.train = train
self.speech_name = speech_name
self.text_name = text_name
self.speech_volume_normalize = speech_volume_normalize
self.rir_apply_prob = rir_apply_prob
self.noise_apply_prob = noise_apply_prob
self.split_with_space = split_with_space
self.seg_dict = None
if seg_dict_file is not None:
self.seg_dict = {}
with open(seg_dict_file) as f:
lines = f.readlines()
for line in lines:
s = line.strip().split()
key = s[0]
value = s[1:]
self.seg_dict[key] = " ".join(value)
if token_type is not None:
if token_list is None:
raise ValueError("token_list is required if token_type is not None")
self.text_cleaner = TextCleaner(text_cleaner)
self.tokenizer = build_tokenizer(
token_type=token_type,
bpemodel=bpemodel,
delimiter=delimiter,
space_symbol=space_symbol,
non_linguistic_symbols=non_linguistic_symbols,
g2p_type=g2p_type,
)
self.token_id_converter = TokenIDConverter(
token_list=token_list,
unk_symbol=unk_symbol,
)
else:
self.text_cleaner = None
self.tokenizer = None
self.token_id_converter = None
if train and rir_scp is not None:
self.rirs = []
with open(rir_scp, "r", encoding="utf-8") as f:
for line in f:
sps = line.strip().split(None, 1)
if len(sps) == 1:
self.rirs.append(sps[0])
else:
self.rirs.append(sps[1])
else:
self.rirs = None
if train and noise_scp is not None:
self.noises = []
with open(noise_scp, "r", encoding="utf-8") as f:
for line in f:
sps = line.strip().split(None, 1)
if len(sps) == 1:
self.noises.append(sps[0])
else:
self.noises.append(sps[1])
sps = noise_db_range.split("_")
if len(sps) == 1:
self.noise_db_low = self.noise_db_high = float(sps[0])
elif len(sps) == 2:
self.noise_db_low, self.noise_db_high = float(sps[0]), float(sps[1])
else:
raise ValueError(
"Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]"
)
else:
self.noises = None
def _speech_process(
self, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, Union[str, np.ndarray]]:
assert check_argument_types()
if self.speech_name in data:
if self.train and (self.rirs is not None or self.noises is not None):
speech = data[self.speech_name]
nsamples = len(speech)
# speech: (Nmic, Time)
if speech.ndim == 1:
speech = speech[None, :]
else:
speech = speech.T
# Calc power on the non-silence region
power = (speech[detect_non_silence(speech)] ** 2).mean()
# 1. Convolve RIR
if self.rirs is not None and self.rir_apply_prob >= np.random.random():
rir_path = np.random.choice(self.rirs)
if rir_path is not None:
rir, _ = soundfile.read(
rir_path, dtype=np.float64, always_2d=True
)
# rir: (Nmic, Time)
rir = rir.T
# speech: (Nmic, Time)
# Note that this operation doesn't change the signal length
speech = scipy.signal.convolve(speech, rir, mode="full")[
:, : speech.shape[1]
]
# Reverse mean power to the original power
power2 = (speech[detect_non_silence(speech)] ** 2).mean()
speech = np.sqrt(power / max(power2, 1e-10)) * speech
# 2. Add Noise
if (
self.noises is not None
and self.noise_apply_prob >= np.random.random()
):
noise_path = np.random.choice(self.noises)
if noise_path is not None:
noise_db = np.random.uniform(
self.noise_db_low, self.noise_db_high
)
with soundfile.SoundFile(noise_path) as f:
if f.frames == nsamples:
noise = f.read(dtype=np.float64, always_2d=True)
elif f.frames < nsamples:
offset = np.random.randint(0, nsamples - f.frames)
# noise: (Time, Nmic)
noise = f.read(dtype=np.float64, always_2d=True)
# Repeat noise
noise = np.pad(
noise,
[(offset, nsamples - f.frames - offset), (0, 0)],
mode="wrap",
)
else:
offset = np.random.randint(0, f.frames - nsamples)
f.seek(offset)
# noise: (Time, Nmic)
noise = f.read(
nsamples, dtype=np.float64, always_2d=True
)
if len(noise) != nsamples:
raise RuntimeError(f"Something wrong: {noise_path}")
# noise: (Nmic, Time)
noise = noise.T
noise_power = (noise ** 2).mean()
scale = (
10 ** (-noise_db / 20)
* np.sqrt(power)
/ np.sqrt(max(noise_power, 1e-10))
)
speech = speech + scale * noise
speech = speech.T
ma = np.max(np.abs(speech))
if ma > 1.0:
speech /= ma
data[self.speech_name] = speech
if self.speech_volume_normalize is not None:
speech = data[self.speech_name]
ma = np.max(np.abs(speech))
data[self.speech_name] = speech * self.speech_volume_normalize / ma
assert check_return_type(data)
return data
def _text_process(
self, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
if self.text_name in data and self.tokenizer is not None:
text = data[self.text_name]
text = self.text_cleaner(text)
if self.split_with_space:
tokens = text.strip().split(" ")
if self.seg_dict is not None:
tokens = seg_tokenize(tokens, self.seg_dict)
else:
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[self.text_name] = np.array(text_ints, dtype=np.int64)
assert check_return_type(data)
return data
def __call__(
self, uid: str, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
assert check_argument_types()
data = self._speech_process(data)
data = self._text_process(data)
return data
## FIXME
class LMPreprocessor(CommonPreprocessor):
def __init__(
self,
train: bool,
token_type: str = None,
token_list: Union[Path, str, Iterable[str]] = None,
bpemodel: Union[Path, str, Iterable[str]] = None,
text_cleaner: Collection[str] = None,
g2p_type: str = None,
unk_symbol: str = "<unk>",
space_symbol: str = "<space>",
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
delimiter: str = None,
rir_scp: str = None,
rir_apply_prob: float = 1.0,
noise_scp: str = None,
noise_apply_prob: float = 1.0,
noise_db_range: str = "3_10",
speech_volume_normalize: float = None,
speech_name: str = "speech",
text_name: str = "text",
split_with_space: bool = False,
seg_dict_file: str = None,
):
super().__init__(train,
token_type,
token_list,
bpemodel,
text_cleaner,
g2p_type,
unk_symbol,
space_symbol,
non_linguistic_symbols,
delimiter,
rir_scp,
rir_apply_prob,
noise_scp,
noise_apply_prob,
noise_db_range,
speech_volume_normalize,
speech_name,
text_name,
split_with_space,
seg_dict_file,
)
def _text_process(
self, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
if self.text_name in data and self.tokenizer is not None:
text = data[self.text_name]
text = self.text_cleaner(text)
if self.split_with_space:
tokens = text.strip().split(" ")
if self.seg_dict is not None:
tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
else:
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[self.text_name] = np.array(text_ints, dtype=np.int64)
assert check_return_type(data)
return data
class CommonPreprocessor_multi(AbsPreprocessor):
def __init__(
self,
train: bool,
token_type: str = None,
token_list: Union[Path, str, Iterable[str]] = None,
bpemodel: Union[Path, str, Iterable[str]] = None,
text_cleaner: Collection[str] = None,
g2p_type: str = None,
unk_symbol: str = "<unk>",
space_symbol: str = "<space>",
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
delimiter: str = None,
speech_name: str = "speech",
text_name: List[str] = ["text"],
):
super().__init__(train)
self.train = train
self.speech_name = speech_name
self.text_name = text_name
if token_type is not None:
if token_list is None:
raise ValueError("token_list is required if token_type is not None")
self.text_cleaner = TextCleaner(text_cleaner)
self.tokenizer = build_tokenizer(
token_type=token_type,
bpemodel=bpemodel,
delimiter=delimiter,
space_symbol=space_symbol,
non_linguistic_symbols=non_linguistic_symbols,
g2p_type=g2p_type,
)
self.token_id_converter = TokenIDConverter(
token_list=token_list,
unk_symbol=unk_symbol,
)
else:
self.text_cleaner = None
self.tokenizer = None
self.token_id_converter = None
def _text_process(
self, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
for text_n in self.text_name:
if text_n in data and self.tokenizer is not None:
text = data[text_n]
text = self.text_cleaner(text)
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[text_n] = np.array(text_ints, dtype=np.int64)
assert check_return_type(data)
return data
def __call__(
self, uid: str, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
assert check_argument_types()
if self.speech_name in data:
# Nothing now: candidates:
# - STFT
# - Fbank
# - CMVN
# - Data augmentation
pass
data = self._text_process(data)
return data
class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
def __init__(
self,
train: bool,
token_type: List[str] = [None],
token_list: List[Union[Path, str, Iterable[str]]] = [None],
bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
text_cleaner: Collection[str] = None,
g2p_type: str = None,
unk_symbol: str = "<unk>",
space_symbol: str = "<space>",
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
delimiter: str = None,
rir_scp: str = None,
rir_apply_prob: float = 1.0,
noise_scp: str = None,
noise_apply_prob: float = 1.0,
noise_db_range: str = "3_10",
speech_volume_normalize: float = None,
speech_name: str = "speech",
text_name: List[str] = ["text"],
):
# TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
super().__init__(
train=train,
token_type=token_type[0],
token_list=token_list[0],
bpemodel=bpemodel[0],
text_cleaner=text_cleaner,
g2p_type=g2p_type,
unk_symbol=unk_symbol,
space_symbol=space_symbol,
non_linguistic_symbols=non_linguistic_symbols,
delimiter=delimiter,
speech_name=speech_name,
text_name=text_name[0],
rir_scp=rir_scp,
rir_apply_prob=rir_apply_prob,
noise_scp=noise_scp,
noise_apply_prob=noise_apply_prob,
noise_db_range=noise_db_range,
speech_volume_normalize=speech_volume_normalize,
)
assert (
len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
), "token_type, token_list, bpemodel, or processing text_name mismatched"
self.num_tokenizer = len(token_type)
self.tokenizer = []
self.token_id_converter = []
for i in range(self.num_tokenizer):
if token_type[i] is not None:
if token_list[i] is None:
raise ValueError("token_list is required if token_type is not None")
self.tokenizer.append(
build_tokenizer(
token_type=token_type[i],
bpemodel=bpemodel[i],
delimiter=delimiter,
space_symbol=space_symbol,
non_linguistic_symbols=non_linguistic_symbols,
g2p_type=g2p_type,
)
)
self.token_id_converter.append(
TokenIDConverter(
token_list=token_list[i],
unk_symbol=unk_symbol,
)
)
else:
self.tokenizer.append(None)
self.token_id_converter.append(None)
self.text_cleaner = TextCleaner(text_cleaner)
self.text_name = text_name # override the text_name from CommonPreprocessor
def _text_process(
self, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
for i in range(self.num_tokenizer):
text_name = self.text_name[i]
if text_name in data and self.tokenizer[i] is not None:
text = data[text_name]
text = self.text_cleaner(text)
tokens = self.tokenizer[i].text2tokens(text)
text_ints = self.token_id_converter[i].tokens2ids(tokens)
data[text_name] = np.array(text_ints, dtype=np.int64)
assert check_return_type(data)
return data
class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
def __init__(
self,
train: bool,
token_type: str = None,
token_list: Union[Path, str, Iterable[str]] = None,
bpemodel: Union[Path, str, Iterable[str]] = None,
text_cleaner: Collection[str] = None,
g2p_type: str = None,
unk_symbol: str = "<unk>",
space_symbol: str = "<space>",
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
delimiter: str = None,
rir_scp: str = None,
rir_apply_prob: float = 1.0,
noise_scp: str = None,
noise_apply_prob: float = 1.0,
noise_db_range: str = "3_10",
speech_volume_normalize: float = None,
speech_name: str = "speech",
text_name: str = "text",
split_text_name: str = "split_text",
split_with_space: bool = False,
seg_dict_file: str = None,
):
super().__init__(
train=train,
# Force to use word.
token_type="word",
token_list=token_list,
bpemodel=bpemodel,
text_cleaner=text_cleaner,
g2p_type=g2p_type,
unk_symbol=unk_symbol,
space_symbol=space_symbol,
non_linguistic_symbols=non_linguistic_symbols,
delimiter=delimiter,
speech_name=speech_name,
text_name=text_name,
rir_scp=rir_scp,
rir_apply_prob=rir_apply_prob,
noise_scp=noise_scp,
noise_apply_prob=noise_apply_prob,
noise_db_range=noise_db_range,
speech_volume_normalize=speech_volume_normalize,
split_with_space=split_with_space,
seg_dict_file=seg_dict_file,
)
# The data field name for split text.
self.split_text_name = split_text_name
@classmethod
def split_words(cls, text: str):
words = []
segs = text.split()
for seg in segs:
# There is no space in seg.
current_word = ""
for c in seg:
if len(c.encode()) == 1:
# This is an ASCII char.
current_word += c
else:
# This is a Chinese char.
if len(current_word) > 0:
words.append(current_word)
current_word = ""
words.append(c)
if len(current_word) > 0:
words.append(current_word)
return words
def __call__(
self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
) -> Dict[str, Union[list, np.ndarray]]:
assert check_argument_types()
# Split words.
if isinstance(data[self.text_name], str):
split_text = self.split_words(data[self.text_name])
else:
split_text = data[self.text_name]
data[self.text_name] = " ".join(split_text)
data = self._speech_process(data)
data = self._text_process(data)
data[self.split_text_name] = split_text
return data
def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
result = data[self.split_text_name]
del data[self.split_text_name]
return result
class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
def __init__(
self,
train: bool,
token_type: List[str] = [None],
token_list: List[Union[Path, str, Iterable[str]]] = [None],
bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
text_cleaner: Collection[str] = None,
g2p_type: str = None,
unk_symbol: str = "<unk>",
space_symbol: str = "<space>",
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
delimiter: str = None,
rir_scp: str = None,
rir_apply_prob: float = 1.0,
noise_scp: str = None,
noise_apply_prob: float = 1.0,
noise_db_range: str = "3_10",
speech_volume_normalize: float = None,
speech_name: str = "speech",
text_name: List[str] = ["text"],
vad_name: str = "vad_indexes",
):
# TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
super().__init__(
train=train,
token_type=token_type[0],
token_list=token_list[0],
bpemodel=bpemodel[0],
text_cleaner=text_cleaner,
g2p_type=g2p_type,
unk_symbol=unk_symbol,
space_symbol=space_symbol,
non_linguistic_symbols=non_linguistic_symbols,
delimiter=delimiter,
speech_name=speech_name,
text_name=text_name[0],
rir_scp=rir_scp,
rir_apply_prob=rir_apply_prob,
noise_scp=noise_scp,
noise_apply_prob=noise_apply_prob,
noise_db_range=noise_db_range,
speech_volume_normalize=speech_volume_normalize,
)
assert (
len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
), "token_type, token_list, bpemodel, or processing text_name mismatched"
self.num_tokenizer = len(token_type)
self.tokenizer = []
self.token_id_converter = []
for i in range(self.num_tokenizer):
if token_type[i] is not None:
if token_list[i] is None:
raise ValueError("token_list is required if token_type is not None")
self.tokenizer.append(
build_tokenizer(
token_type=token_type[i],
bpemodel=bpemodel[i],
delimiter=delimiter,
space_symbol=space_symbol,
non_linguistic_symbols=non_linguistic_symbols,
g2p_type=g2p_type,
)
)
self.token_id_converter.append(
TokenIDConverter(
token_list=token_list[i],
unk_symbol=unk_symbol,
)
)
else:
self.tokenizer.append(None)
self.token_id_converter.append(None)
self.text_cleaner = TextCleaner(text_cleaner)
self.text_name = text_name # override the text_name from CommonPreprocessor
self.vad_name = vad_name
def _text_process(
self, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
for i in range(self.num_tokenizer):
text_name = self.text_name[i]
#import pdb; pdb.set_trace()
if text_name in data and self.tokenizer[i] is not None:
text = data[text_name]
text = self.text_cleaner(text)
tokens = self.tokenizer[i].text2tokens(text)
if "vad:" in tokens[-1]:
vad = tokens[-1][4:]
tokens = tokens[:-1]
if len(vad) == 0:
vad = -1
else:
vad = int(vad)
data[self.vad_name] = np.array([vad], dtype=np.int64)
text_ints = self.token_id_converter[i].tokens2ids(tokens)
data[text_name] = np.array(text_ints, dtype=np.int64)
return data
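# Split a flat token list into chunks of at most `word_limit` tokens, e.g.
# split_to_mini_sentence(list("abcdefg"), word_limit=3) -> [['a', 'b', 'c'], ['d', 'e', 'f'], ['g']].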
def split_to_mini_sentence(words: list, word_limit: int = 20):
assert word_limit > 1
if len(words) <= word_limit:
return [words]
sentences = []
length = len(words)
sentence_len = length // word_limit
for i in range(sentence_len):
sentences.append(words[i * word_limit:(i + 1) * word_limit])
if length % word_limit > 0:
sentences.append(words[sentence_len * word_limit:])
return sentences

View File

@@ -0,0 +1,93 @@
# Export models
## Environments
### Install modelscope and funasr
The installation is the same as [funasr](https://github.com/alibaba-damo-academy/FunASR/blob/main/README.md#installation)
```shell
# pip3 install torch torchaudio
pip install -U modelscope funasr
# For the users in China, you could install with the command:
# pip install -U modelscope funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
### Install the quantization tools
```shell
pip install torch-quant # Optional, for torchscript quantization
pip install onnx onnxruntime # Optional, for onnx quantization
```
## Usage
`Tip`: torch>=1.11.0 is required.
```shell
python -m funasr.export.export_model \
--model-name [model_name] \
--export-dir [export_dir] \
--type [onnx, torch] \
--quantize [true, false] \
--fallback-num [fallback_num]
```
`model-name`: the model to export. It can be a model name from ModelScope, or the path to a local finetuned model (the model file must be named `model.pb`).
`export-dir`: the directory where the exported model is saved.
`type`: `onnx` or `torch`, export an ONNX format model or a TorchScript format model.
`quantize`: `true` to export a quantized model in addition to the fp32 model; `false` to export the fp32 model only.
`fallback-num`: the number of layers that fall back to fp32 during automatic mixed-precision quantization.
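For example, the following command exports a quantized TorchScript model with 5 layers falling back to fp32 (the model name is the same ModelScope model used in the examples below):
```shell
python -m funasr.export.export_model \
    --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch \
    --export-dir ./export \
    --type torch \
    --quantize true \
    --fallback-num 5
```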
### Export ONNX format model
#### Export model from modelscope
```shell
python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize false
```
#### Export model from local path
The model file must be named `model.pb`.
```shell
python -m funasr.export.export_model --model-name /mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize false
```
#### Test ONNX model
Refer to [test](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export/test)
### Export TorchScript format model
#### Export model from modelscope
```shell
python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch --quantize false
```
#### Export model from local path
The model file must be named `model.pb`.
```shell
python -m funasr.export.export_model --model-name /mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch --quantize false
```
#### Test TorchScript model
Refer to [test](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export/test)
## Runtime
### ONNXRuntime
#### ONNXRuntime-python
Refer to [docs](https://alibaba-damo-academy.github.io/FunASR/en/runtime/onnxruntime_python.html)
#### ONNXRuntime-cpp
Refer to [docs](https://alibaba-damo-academy.github.io/FunASR/en/runtime/onnxruntime_cpp.html)
### Libtorch
#### Libtorch-python
Refer to [docs](https://alibaba-damo-academy.github.io/FunASR/en/runtime/libtorch_python.html)
#### Libtorch-cpp
Not yet available.
## Performance Benchmark
### Paraformer on CPU
[onnx runtime](https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/python/benchmark_onnx.md)
[libtorch runtime](https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/python/benchmark_libtorch.md)
### Paraformer on GPU
[nv-triton](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/triton_gpu)
## Acknowledgement
Torch model quantization is supported by [BladeDISC](https://github.com/alibaba/BladeDISC), an end-to-end DynamIc Shape Compiler project for machine learning workloads. BladeDISC provides general, transparent, and easy-to-use performance optimization for TensorFlow/PyTorch workloads on GPGPU and CPU backends. If you are interested, please contact us.

View File

View File

@@ -0,0 +1,285 @@
import json
from typing import Union, Dict
from pathlib import Path
from typeguard import check_argument_types
import os
import logging
import torch
from funasr_local.export.models import get_model
import numpy as np
import random
from funasr_local.utils.types import str2bool
# torch_version = float(".".join(torch.__version__.split(".")[:2]))
# assert torch_version > 1.9
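# Typical command-line usage (flags are defined in the argparse block at the bottom of this
# file; the module path assumes the package layout funasr_local/export/export_model.py):
#   python -m funasr_local.export.export_model --model-name <modelscope_id_or_local_dir> \
#       --export-dir ./export --type onnx --quantize false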
class ModelExport:
def __init__(
self,
cache_dir: Union[Path, str] = None,
onnx: bool = True,
device: str = "cpu",
quant: bool = True,
fallback_num: int = 0,
audio_in: str = None,
calib_num: int = 200,
):
assert check_argument_types()
self.set_all_random_seed(0)
if cache_dir is None:
cache_dir = Path.home() / ".cache" / "export"
self.cache_dir = Path(cache_dir)
self.export_config = dict(
feats_dim=560,
onnx=False,
)
print("output dir: {}".format(self.cache_dir))
self.onnx = onnx
self.device = device
self.quant = quant
self.fallback_num = fallback_num
self.frontend = None
self.audio_in = audio_in
self.calib_num = calib_num
def _export(
self,
model,
tag_name: str = None,
verbose: bool = False,
):
export_dir = self.cache_dir / tag_name.replace(' ', '-')
os.makedirs(export_dir, exist_ok=True)
# export encoder1
self.export_config["model_name"] = "model"
model = get_model(
model,
self.export_config,
)
model.eval()
# self._export_onnx(model, verbose, export_dir)
if self.onnx:
self._export_onnx(model, verbose, export_dir)
else:
self._export_torchscripts(model, verbose, export_dir)
print("output dir: {}".format(export_dir))
def _torch_quantize(self, model):
def _run_calibration_data(m):
            # Run calibration on real audio if provided; otherwise use dummy inputs.
            if self.audio_in is not None:
                feats, feats_len = self.load_feats(self.audio_in)
                for feat, feat_len in zip(feats, feats_len):
                    with torch.no_grad():
                        m(feat, feat_len)
else:
dummy_input = model.get_dummy_inputs()
m(*dummy_input)
from torch_quant.module import ModuleFilter
from torch_quant.quantizer import Backend, Quantizer
from funasr_local.export.models.modules.decoder_layer import DecoderLayerSANM
from funasr_local.export.models.modules.encoder_layer import EncoderLayerSANM
module_filter = ModuleFilter(include_classes=[EncoderLayerSANM, DecoderLayerSANM])
module_filter.exclude_op_types = [torch.nn.Conv1d]
quantizer = Quantizer(
module_filter=module_filter,
backend=Backend.FBGEMM,
)
model.eval()
calib_model = quantizer.calib(model)
_run_calibration_data(calib_model)
if self.fallback_num > 0:
# perform automatic mixed precision quantization
amp_model = quantizer.amp(model)
_run_calibration_data(amp_model)
quantizer.fallback(amp_model, num=self.fallback_num)
print('Fallback layers:')
print('\n'.join(quantizer.module_filter.exclude_names))
quant_model = quantizer.quantize(model)
return quant_model
def _export_torchscripts(self, model, verbose, path, enc_size=None):
if enc_size:
dummy_input = model.get_dummy_inputs(enc_size)
else:
dummy_input = model.get_dummy_inputs()
if self.device == 'cuda':
model = model.cuda()
dummy_input = tuple([i.cuda() for i in dummy_input])
# model_script = torch.jit.script(model)
model_script = torch.jit.trace(model, dummy_input)
model_script.save(os.path.join(path, f'{model.model_name}.torchscripts'))
if self.quant:
quant_model = self._torch_quantize(model)
model_script = torch.jit.trace(quant_model, dummy_input)
model_script.save(os.path.join(path, f'{model.model_name}_quant.torchscripts'))
def set_all_random_seed(self, seed: int):
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
def parse_audio_in(self, audio_in):
wav_list, name_list = [], []
if audio_in.endswith(".scp"):
f = open(audio_in, 'r')
lines = f.readlines()[:self.calib_num]
for line in lines:
name, path = line.strip().split()
name_list.append(name)
wav_list.append(path)
else:
wav_list = [audio_in,]
name_list = ["test",]
return wav_list, name_list
def load_feats(self, audio_in: str = None):
import torchaudio
wav_list, name_list = self.parse_audio_in(audio_in)
feats = []
feats_len = []
for line in wav_list:
path = line.strip()
waveform, sampling_rate = torchaudio.load(path)
if sampling_rate != self.frontend.fs:
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
new_freq=self.frontend.fs)(waveform)
fbank, fbank_len = self.frontend(waveform, [waveform.size(1)])
feats.append(fbank)
feats_len.append(fbank_len)
return feats, feats_len
def export(self,
tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
mode: str = None,
):
model_dir = tag_name
if model_dir.startswith('damo'):
from modelscope.hub.snapshot_download import snapshot_download
model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
if mode is None:
import json
json_file = os.path.join(model_dir, 'configuration.json')
with open(json_file, 'r') as f:
config_data = json.load(f)
if config_data['task'] == "punctuation":
mode = config_data['model']['punc_model_config']['mode']
else:
mode = config_data['model']['model_config']['mode']
if mode.startswith('paraformer'):
from funasr_local.tasks.asr import ASRTaskParaformer as ASRTask
config = os.path.join(model_dir, 'config.yaml')
model_file = os.path.join(model_dir, 'model.pb')
cmvn_file = os.path.join(model_dir, 'am.mvn')
model, asr_train_args = ASRTask.build_model_from_file(
config, model_file, cmvn_file, 'cpu'
)
self.frontend = model.frontend
elif mode.startswith('offline'):
from funasr_local.tasks.vad import VADTask
config = os.path.join(model_dir, 'vad.yaml')
model_file = os.path.join(model_dir, 'vad.pb')
cmvn_file = os.path.join(model_dir, 'vad.mvn')
model, vad_infer_args = VADTask.build_model_from_file(
config, model_file, cmvn_file=cmvn_file, device='cpu'
)
self.export_config["feats_dim"] = 400
self.frontend = model.frontend
elif mode.startswith('punc'):
from funasr_local.tasks.punctuation import PunctuationTask as PUNCTask
punc_train_config = os.path.join(model_dir, 'config.yaml')
punc_model_file = os.path.join(model_dir, 'punc.pb')
model, punc_train_args = PUNCTask.build_model_from_file(
punc_train_config, punc_model_file, 'cpu'
)
elif mode.startswith('punc_VadRealtime'):
from funasr_local.tasks.punctuation import PunctuationTask as PUNCTask
punc_train_config = os.path.join(model_dir, 'config.yaml')
punc_model_file = os.path.join(model_dir, 'punc.pb')
model, punc_train_args = PUNCTask.build_model_from_file(
punc_train_config, punc_model_file, 'cpu'
)
self._export(model, tag_name)
def _export_onnx(self, model, verbose, path, enc_size=None):
if enc_size:
dummy_input = model.get_dummy_inputs(enc_size)
else:
dummy_input = model.get_dummy_inputs()
# model_script = torch.jit.script(model)
model_script = model #torch.jit.trace(model)
model_path = os.path.join(path, f'{model.model_name}.onnx')
torch.onnx.export(
model_script,
dummy_input,
model_path,
verbose=verbose,
opset_version=14,
input_names=model.get_input_names(),
output_names=model.get_output_names(),
dynamic_axes=model.get_dynamic_axes()
)
if self.quant:
from onnxruntime.quantization import QuantType, quantize_dynamic
import onnx
quant_model_path = os.path.join(path, f'{model.model_name}_quant.onnx')
onnx_model = onnx.load(model_path)
nodes = [n.name for n in onnx_model.graph.node]
nodes_to_exclude = [m for m in nodes if 'output' in m]
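            # Exclude nodes whose names contain 'output' from dynamic quantization
            # (only MatMul ops are quantized below).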
quantize_dynamic(
model_input=model_path,
model_output=quant_model_path,
op_types_to_quantize=['MatMul'],
per_channel=True,
reduce_range=False,
weight_type=QuantType.QUInt8,
nodes_to_exclude=nodes_to_exclude,
)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--export-dir', type=str, required=True)
parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
args = parser.parse_args()
export_model = ModelExport(
cache_dir=args.export_dir,
onnx=args.type == 'onnx',
device=args.device,
quant=args.quantize,
fallback_num=args.fallback_num,
audio_in=args.audio_in,
calib_num=args.calib_num,
)
export_model.export(args.model_name)

View File

@@ -0,0 +1,162 @@
from typing import Tuple
import torch
import torch.nn as nn
from funasr_local.models.encoder.sanm_encoder import SANMEncoder
from funasr_local.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr_local.models.encoder.sanm_encoder import SANMVadEncoder
from funasr_local.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
class CT_Transformer(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
def __init__(
self,
model,
max_seq_len=512,
model_name='punc_model',
**kwargs,
):
super().__init__()
onnx = False
if "onnx" in kwargs:
onnx = kwargs["onnx"]
self.embed = model.embed
self.decoder = model.decoder
# self.model = model
self.feats_dim = self.embed.embedding_dim
self.num_embeddings = self.embed.num_embeddings
self.model_name = model_name
if isinstance(model.encoder, SANMEncoder):
self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
else:
assert False, "Only support samn encode."
    def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> torch.Tensor:
        """Compute punctuation logits for the input token sequences.
        Args:
            inputs (torch.Tensor): Input token ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence. (batch,)
        """
x = self.embed(inputs)
# mask = self._target_mask(input)
h, _ = self.encoder(x, text_lengths)
y = self.decoder(h)
return y
def get_dummy_inputs(self):
length = 120
text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
return (text_indexes, text_lengths)
def get_input_names(self):
return ['inputs', 'text_lengths']
def get_output_names(self):
return ['logits']
def get_dynamic_axes(self):
return {
'inputs': {
0: 'batch_size',
1: 'feats_length'
},
'text_lengths': {
0: 'batch_size',
},
'logits': {
0: 'batch_size',
1: 'logits_length'
},
}
class CT_Transformer_VadRealtime(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
def __init__(
self,
model,
max_seq_len=512,
model_name='punc_model',
**kwargs,
):
super().__init__()
onnx = False
if "onnx" in kwargs:
onnx = kwargs["onnx"]
self.embed = model.embed
if isinstance(model.encoder, SANMVadEncoder):
self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
else:
assert False, "Only support samn encode."
self.decoder = model.decoder
self.model_name = model_name
    def forward(self, inputs: torch.Tensor,
                text_lengths: torch.Tensor,
                vad_indexes: torch.Tensor,
                sub_masks: torch.Tensor,
                ) -> torch.Tensor:
        """Compute punctuation logits for the input token sequences.
        Args:
            inputs (torch.Tensor): Input token ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence. (batch,)
            vad_indexes (torch.Tensor): VAD attention masks. (batch, 1, len, len)
            sub_masks (torch.Tensor): Causal sub-sequence masks. (batch, 1, len, len)
        """
x = self.embed(inputs)
# mask = self._target_mask(input)
h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
y = self.decoder(h)
return y
def with_vad(self):
return True
def get_dummy_inputs(self):
length = 120
text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
text_lengths = torch.tensor([length], dtype=torch.int32)
vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
sub_masks = torch.ones(length, length, dtype=torch.float32)
sub_masks = torch.tril(sub_masks).type(torch.float32)
return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
def get_input_names(self):
return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
def get_output_names(self):
return ['logits']
def get_dynamic_axes(self):
return {
'inputs': {
1: 'feats_length'
},
'vad_masks': {
2: 'feats_length1',
3: 'feats_length2'
},
'sub_masks': {
2: 'feats_length1',
3: 'feats_length2'
},
'logits': {
1: 'logits_length'
},
}

View File

@@ -0,0 +1,25 @@
from funasr_local.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
from funasr_local.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr_local.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr_local.models.e2e_vad import E2EVadModel
from funasr_local.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
from funasr_local.models.target_delay_transformer import TargetDelayTransformer
from funasr_local.export.models.CT_Transformer import CT_Transformer as CT_Transformer_export
from funasr_local.train.abs_model import PunctuationModel
from funasr_local.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr_local.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
def get_model(model, export_config=None):
if isinstance(model, BiCifParaformer):
return BiCifParaformer_export(model, **export_config)
elif isinstance(model, Paraformer):
return Paraformer_export(model, **export_config)
elif isinstance(model, E2EVadModel):
return E2EVadModel_export(model, **export_config)
elif isinstance(model, PunctuationModel):
if isinstance(model.punc_model, TargetDelayTransformer):
return CT_Transformer_export(model.punc_model, **export_config)
elif isinstance(model.punc_model, VadRealtimeTransformer):
return CT_Transformer_VadRealtime_export(model.punc_model, **export_config)
    else:
        raise ValueError("FunASR does not support the given model type currently.")

View File

@@ -0,0 +1,159 @@
import os
import torch
import torch.nn as nn
from funasr_local.export.utils.torch_function import MakePadMask
from funasr_local.export.utils.torch_function import sequence_mask
from funasr_local.modules.attention import MultiHeadedAttentionSANMDecoder
from funasr_local.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
from funasr_local.modules.attention import MultiHeadedAttentionCrossAtt
from funasr_local.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
from funasr_local.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr_local.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
from funasr_local.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export
class ParaformerSANMDecoder(nn.Module):
def __init__(self, model,
max_seq_len=512,
model_name='decoder',
onnx: bool = True,):
super().__init__()
# self.embed = model.embed #Embedding(model.embed, max_seq_len)
self.model = model
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
for i, d in enumerate(self.model.decoders):
if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
self.model.decoders[i] = DecoderLayerSANM_export(d)
if self.model.decoders2 is not None:
for i, d in enumerate(self.model.decoders2):
if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
self.model.decoders2[i] = DecoderLayerSANM_export(d)
for i, d in enumerate(self.model.decoders3):
if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
self.model.decoders3[i] = DecoderLayerSANM_export(d)
self.output_layer = model.output_layer
self.after_norm = model.after_norm
self.model_name = model_name
def prepare_mask(self, mask):
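        # Build two views of the padding mask: a (B, T, 1) multiplicative mask applied to
        # feature tensors, and an additive attention bias with -10000.0 at padded positions.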
mask_3d_btd = mask[:, :, None]
if len(mask.shape) == 2:
mask_4d_bhlt = 1 - mask[:, None, None, :]
elif len(mask.shape) == 3:
mask_4d_bhlt = 1 - mask[:, None, :]
mask_4d_bhlt = mask_4d_bhlt * -10000.0
return mask_3d_btd, mask_4d_bhlt
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
):
tgt = ys_in_pad
tgt_mask = self.make_pad_mask(ys_in_lens)
tgt_mask, _ = self.prepare_mask(tgt_mask)
# tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
memory = hs_pad
memory_mask = self.make_pad_mask(hlens)
_, memory_mask = self.prepare_mask(memory_mask)
# memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
x = tgt
x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
x, tgt_mask, memory, memory_mask
)
if self.model.decoders2 is not None:
x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
x, tgt_mask, memory, memory_mask
)
x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
x, tgt_mask, memory, memory_mask
)
x = self.after_norm(x)
x = self.output_layer(x)
return x, ys_in_lens
def get_dummy_inputs(self, enc_size):
tgt = torch.LongTensor([0]).unsqueeze(0)
memory = torch.randn(1, 100, enc_size)
pre_acoustic_embeds = torch.randn(1, 1, enc_size)
cache_num = len(self.model.decoders) + len(self.model.decoders2)
cache = [
torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
for _ in range(cache_num)
]
return (tgt, memory, pre_acoustic_embeds, cache)
def is_optimizable(self):
return True
def get_input_names(self):
cache_num = len(self.model.decoders) + len(self.model.decoders2)
return ['tgt', 'memory', 'pre_acoustic_embeds'] \
+ ['cache_%d' % i for i in range(cache_num)]
def get_output_names(self):
cache_num = len(self.model.decoders) + len(self.model.decoders2)
return ['y'] \
+ ['out_cache_%d' % i for i in range(cache_num)]
def get_dynamic_axes(self):
ret = {
'tgt': {
0: 'tgt_batch',
1: 'tgt_length'
},
'memory': {
0: 'memory_batch',
1: 'memory_length'
},
'pre_acoustic_embeds': {
0: 'acoustic_embeds_batch',
1: 'acoustic_embeds_length',
}
}
cache_num = len(self.model.decoders) + len(self.model.decoders2)
ret.update({
'cache_%d' % d: {
0: 'cache_%d_batch' % d,
2: 'cache_%d_length' % d
}
for d in range(cache_num)
})
return ret
def get_model_config(self, path):
return {
"dec_type": "XformerDecoder",
"model_path": os.path.join(path, f'{self.model_name}.onnx'),
"n_layers": len(self.model.decoders) + len(self.model.decoders2),
"odim": self.model.decoders[0].size
}

View File

@@ -0,0 +1,143 @@
import os
from funasr_local.export import models
import torch
import torch.nn as nn
from funasr_local.export.utils.torch_function import MakePadMask
from funasr_local.export.utils.torch_function import sequence_mask
from funasr_local.modules.attention import MultiHeadedAttentionSANMDecoder
from funasr_local.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
from funasr_local.modules.attention import MultiHeadedAttentionCrossAtt, MultiHeadedAttention
from funasr_local.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
from funasr_local.export.models.modules.multihead_att import OnnxMultiHeadedAttention
from funasr_local.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr_local.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
from funasr_local.export.models.modules.decoder_layer import DecoderLayer as DecoderLayer_export
class ParaformerDecoderSAN(nn.Module):
def __init__(self, model,
max_seq_len=512,
model_name='decoder',
onnx: bool = True,):
super().__init__()
# self.embed = model.embed #Embedding(model.embed, max_seq_len)
self.model = model
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
for i, d in enumerate(self.model.decoders):
if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
# if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
# d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
if isinstance(d.src_attn, MultiHeadedAttention):
d.src_attn = OnnxMultiHeadedAttention(d.src_attn)
self.model.decoders[i] = DecoderLayer_export(d)
self.output_layer = model.output_layer
self.after_norm = model.after_norm
self.model_name = model_name
def prepare_mask(self, mask):
mask_3d_btd = mask[:, :, None]
if len(mask.shape) == 2:
mask_4d_bhlt = 1 - mask[:, None, None, :]
elif len(mask.shape) == 3:
mask_4d_bhlt = 1 - mask[:, None, :]
mask_4d_bhlt = mask_4d_bhlt * -10000.0
return mask_3d_btd, mask_4d_bhlt
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
):
tgt = ys_in_pad
tgt_mask = self.make_pad_mask(ys_in_lens)
tgt_mask, _ = self.prepare_mask(tgt_mask)
# tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
memory = hs_pad
memory_mask = self.make_pad_mask(hlens)
_, memory_mask = self.prepare_mask(memory_mask)
# memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
x = tgt
x, tgt_mask, memory, memory_mask = self.model.decoders(
x, tgt_mask, memory, memory_mask
)
x = self.after_norm(x)
x = self.output_layer(x)
return x, ys_in_lens
def get_dummy_inputs(self, enc_size):
tgt = torch.LongTensor([0]).unsqueeze(0)
memory = torch.randn(1, 100, enc_size)
pre_acoustic_embeds = torch.randn(1, 1, enc_size)
cache_num = len(self.model.decoders) + len(self.model.decoders2)
cache = [
torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
for _ in range(cache_num)
]
return (tgt, memory, pre_acoustic_embeds, cache)
def is_optimizable(self):
return True
def get_input_names(self):
cache_num = len(self.model.decoders) + len(self.model.decoders2)
return ['tgt', 'memory', 'pre_acoustic_embeds'] \
+ ['cache_%d' % i for i in range(cache_num)]
def get_output_names(self):
cache_num = len(self.model.decoders) + len(self.model.decoders2)
return ['y'] \
+ ['out_cache_%d' % i for i in range(cache_num)]
def get_dynamic_axes(self):
ret = {
'tgt': {
0: 'tgt_batch',
1: 'tgt_length'
},
'memory': {
0: 'memory_batch',
1: 'memory_length'
},
'pre_acoustic_embeds': {
0: 'acoustic_embeds_batch',
1: 'acoustic_embeds_length',
}
}
cache_num = len(self.model.decoders) + len(self.model.decoders2)
ret.update({
'cache_%d' % d: {
0: 'cache_%d_batch' % d,
2: 'cache_%d_length' % d
}
for d in range(cache_num)
})
return ret
def get_model_config(self, path):
return {
"dec_type": "XformerDecoder",
"model_path": os.path.join(path, f'{self.model_name}.onnx'),
"n_layers": len(self.model.decoders) + len(self.model.decoders2),
"odim": self.model.decoders[0].size
}

View File

@@ -0,0 +1,219 @@
import logging
import torch
import torch.nn as nn
from funasr_local.export.utils.torch_function import MakePadMask
from funasr_local.export.utils.torch_function import sequence_mask
from funasr_local.models.encoder.sanm_encoder import SANMEncoder
from funasr_local.models.encoder.conformer_encoder import ConformerEncoder
from funasr_local.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr_local.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
from funasr_local.models.predictor.cif import CifPredictorV2, CifPredictorV3
from funasr_local.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
from funasr_local.export.models.predictor.cif import CifPredictorV3 as CifPredictorV3_export
from funasr_local.models.decoder.sanm_decoder import ParaformerSANMDecoder
from funasr_local.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr_local.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
from funasr_local.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
class Paraformer(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
model,
max_seq_len=512,
feats_dim=560,
model_name='model',
**kwargs,
):
super().__init__()
onnx = False
if "onnx" in kwargs:
onnx = kwargs["onnx"]
if isinstance(model.encoder, SANMEncoder):
self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
elif isinstance(model.encoder, ConformerEncoder):
self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
if isinstance(model.predictor, CifPredictorV2):
self.predictor = CifPredictorV2_export(model.predictor)
if isinstance(model.decoder, ParaformerSANMDecoder):
self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
elif isinstance(model.decoder, ParaformerDecoderSAN):
self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
self.feats_dim = feats_dim
self.model_name = model_name
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
):
# a. To device
batch = {"speech": speech, "speech_lengths": speech_lengths}
# batch = to_device(batch, device=self.device)
enc, enc_len = self.encoder(**batch)
mask = self.make_pad_mask(enc_len)[:, None, :]
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
pre_token_length = pre_token_length.floor().type(torch.int32)
decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
decoder_out = torch.log_softmax(decoder_out, dim=-1)
# sample_ids = decoder_out.argmax(dim=-1)
return decoder_out, pre_token_length
def get_dummy_inputs(self):
speech = torch.randn(2, 30, self.feats_dim)
speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
return (speech, speech_lengths)
def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
import numpy as np
fbank = np.loadtxt(txt_file)
fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
return (speech, speech_lengths)
def get_input_names(self):
return ['speech', 'speech_lengths']
def get_output_names(self):
return ['logits', 'token_num']
def get_dynamic_axes(self):
return {
'speech': {
0: 'batch_size',
1: 'feats_length'
},
'speech_lengths': {
0: 'batch_size',
},
'logits': {
0: 'batch_size',
1: 'logits_length'
},
}
class BiCifParaformer(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
model,
max_seq_len=512,
feats_dim=560,
model_name='model',
**kwargs,
):
super().__init__()
onnx = False
if "onnx" in kwargs:
onnx = kwargs["onnx"]
if isinstance(model.encoder, SANMEncoder):
self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
elif isinstance(model.encoder, ConformerEncoder):
self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
else:
logging.warning("Unsupported encoder type to export.")
if isinstance(model.predictor, CifPredictorV3):
self.predictor = CifPredictorV3_export(model.predictor)
else:
logging.warning("Wrong predictor type to export.")
if isinstance(model.decoder, ParaformerSANMDecoder):
self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
elif isinstance(model.decoder, ParaformerDecoderSAN):
self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
else:
logging.warning("Unsupported decoder type to export.")
self.feats_dim = feats_dim
self.model_name = model_name
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
):
# a. To device
batch = {"speech": speech, "speech_lengths": speech_lengths}
# batch = to_device(batch, device=self.device)
enc, enc_len = self.encoder(**batch)
mask = self.make_pad_mask(enc_len)[:, None, :]
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
pre_token_length = pre_token_length.round().type(torch.int32)
decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
decoder_out = torch.log_softmax(decoder_out, dim=-1)
# get predicted timestamps
us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
return decoder_out, pre_token_length, us_alphas, us_cif_peak
def get_dummy_inputs(self):
speech = torch.randn(2, 30, self.feats_dim)
speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
return (speech, speech_lengths)
def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
import numpy as np
fbank = np.loadtxt(txt_file)
fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
return (speech, speech_lengths)
def get_input_names(self):
return ['speech', 'speech_lengths']
def get_output_names(self):
return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']
def get_dynamic_axes(self):
return {
'speech': {
0: 'batch_size',
1: 'feats_length'
},
'speech_lengths': {
0: 'batch_size',
},
'logits': {
0: 'batch_size',
1: 'logits_length'
},
'us_alphas': {
0: 'batch_size',
1: 'alphas_length'
},
'us_cif_peak': {
0: 'batch_size',
1: 'alphas_length'
},
}

View File

@@ -0,0 +1,60 @@
from enum import Enum
from typing import List, Tuple, Dict, Any
import torch
from torch import nn
import math
from funasr_local.models.encoder.fsmn_encoder import FSMN
from funasr_local.export.models.encoder.fsmn_encoder import FSMN as FSMN_export
class E2EVadModel(nn.Module):
def __init__(self, model,
max_seq_len=512,
feats_dim=400,
model_name='model',
**kwargs,):
super(E2EVadModel, self).__init__()
self.feats_dim = feats_dim
self.max_seq_len = max_seq_len
self.model_name = model_name
if isinstance(model.encoder, FSMN):
self.encoder = FSMN_export(model.encoder)
else:
raise "unsupported encoder"
def forward(self, feats: torch.Tensor, *args, ):
scores, out_caches = self.encoder(feats, *args)
return scores, out_caches
def get_dummy_inputs(self, frame=30):
speech = torch.randn(1, frame, self.feats_dim)
in_cache0 = torch.randn(1, 128, 19, 1)
in_cache1 = torch.randn(1, 128, 19, 1)
in_cache2 = torch.randn(1, 128, 19, 1)
in_cache3 = torch.randn(1, 128, 19, 1)
return (speech, in_cache0, in_cache1, in_cache2, in_cache3)
# def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
# import numpy as np
# fbank = np.loadtxt(txt_file)
# fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
# speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
# speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
# return (speech, speech_lengths)
def get_input_names(self):
return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
def get_output_names(self):
return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
def get_dynamic_axes(self):
return {
'speech': {
1: 'feats_length'
},
}

View File

@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
from funasr_local.export.utils.torch_function import MakePadMask
from funasr_local.export.utils.torch_function import sequence_mask
from funasr_local.modules.attention import MultiHeadedAttentionSANM
from funasr_local.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export
from funasr_local.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
from funasr_local.export.models.modules.encoder_layer import EncoderLayerConformer as EncoderLayerConformer_export
from funasr_local.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr_local.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
from funasr_local.export.models.encoder.sanm_encoder import SANMEncoder
from funasr_local.modules.attention import RelPositionMultiHeadedAttention
# from funasr_local.export.models.modules.multihead_att import RelPositionMultiHeadedAttention as RelPositionMultiHeadedAttention_export
from funasr_local.export.models.modules.multihead_att import OnnxRelPosMultiHeadedAttention as RelPositionMultiHeadedAttention_export
class ConformerEncoder(nn.Module):
def __init__(
self,
model,
max_seq_len=512,
feats_dim=560,
model_name='encoder',
onnx: bool = True,
):
super().__init__()
self.embed = model.embed
self.model = model
self.feats_dim = feats_dim
self._output_size = model._output_size
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
for i, d in enumerate(self.model.encoders):
if isinstance(d.self_attn, MultiHeadedAttentionSANM):
d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
if isinstance(d.self_attn, RelPositionMultiHeadedAttention):
d.self_attn = RelPositionMultiHeadedAttention_export(d.self_attn)
if isinstance(d.feed_forward, PositionwiseFeedForward):
d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
self.model.encoders[i] = EncoderLayerConformer_export(d)
self.model_name = model_name
self.num_heads = model.encoders[0].self_attn.h
self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
def prepare_mask(self, mask):
if len(mask.shape) == 2:
mask = 1 - mask[:, None, None, :]
elif len(mask.shape) == 3:
mask = 1 - mask[:, None, :]
return mask * -10000.0
def forward(self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
):
mask = self.make_pad_mask(speech_lengths)
mask = self.prepare_mask(mask)
if self.embed is None:
xs_pad = speech
else:
xs_pad = self.embed(speech)
encoder_outs = self.model.encoders(xs_pad, mask)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if isinstance(xs_pad, tuple):
xs_pad = xs_pad[0]
xs_pad = self.model.after_norm(xs_pad)
return xs_pad, speech_lengths
def get_output_size(self):
return self.model.encoders[0].size
def get_dummy_inputs(self):
feats = torch.randn(1, 100, self.feats_dim)
return (feats)
def get_input_names(self):
return ['feats']
def get_output_names(self):
return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
def get_dynamic_axes(self):
return {
'feats': {
1: 'feats_length'
},
'encoder_out': {
1: 'enc_out_length'
},
'predictor_weight':{
1: 'pre_out_length'
}
}

View File

@@ -0,0 +1,296 @@
from typing import Tuple, Dict
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr_local.models.encoder.fsmn_encoder import BasicBlock
class LinearTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim, bias=False)
def forward(self, input):
output = self.linear(input)
return output
class AffineTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(AffineTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, input):
output = self.linear(input)
return output
class RectifiedLinear(nn.Module):
def __init__(self, input_dim, output_dim):
super(RectifiedLinear, self).__init__()
self.dim = input_dim
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.1)
def forward(self, input):
out = self.relu(input)
return out
class FSMNBlock(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
lorder=None,
rorder=None,
lstride=1,
rstride=1,
):
super(FSMNBlock, self).__init__()
self.dim = input_dim
if lorder is None:
return
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.conv_left = nn.Conv2d(
self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
if self.rorder > 0:
self.conv_right = nn.Conv2d(
self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
else:
self.conv_right = None
def forward(self, input: torch.Tensor, cache: torch.Tensor):
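        # (B, T, D) -> (B, D, T, 1); prepend the cache along the time axis so the depthwise
        # left convolution sees past frames, then keep the latest (lorder - 1) * lstride
        # frames as the new cache for streaming inference.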
x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1) # B D T C
cache = cache.to(x_per.device)
y_left = torch.cat((cache, x_per), dim=2)
cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
y_left = self.conv_left(y_left)
out = x_per + y_left
if self.conv_right is not None:
# maybe need to check
y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
y_right = y_right[:, :, self.rstride:, :]
y_right = self.conv_right(y_right)
out += y_right
out_per = out.permute(0, 3, 2, 1)
output = out_per.squeeze(1)
return output, cache
class BasicBlock_export(nn.Module):
def __init__(self,
model,
):
super(BasicBlock_export, self).__init__()
self.linear = model.linear
self.fsmn_block = model.fsmn_block
self.affine = model.affine
self.relu = model.relu
def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
x = self.linear(input) # B T D
# cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
# if cache_layer_name not in in_cache:
# in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
x, out_cache = self.fsmn_block(x, in_cache)
x = self.affine(x)
x = self.relu(x)
return x, out_cache
# class FsmnStack(nn.Sequential):
# def __init__(self, *args):
# super(FsmnStack, self).__init__(*args)
#
# def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
# x = input
# for module in self._modules.values():
# x = module(x, in_cache)
# return x
'''
FSMN net for keyword spotting
input_dim: input dimension
linear_dim: fsmn input dimension
proj_dim: fsmn projection dimension
lorder: fsmn left order
rorder: fsmn right order
num_syn: output dimension
fsmn_layers: no. of sequential fsmn layers
'''
class FSMN(nn.Module):
def __init__(
self, model,
):
super(FSMN, self).__init__()
# self.input_dim = input_dim
# self.input_affine_dim = input_affine_dim
# self.fsmn_layers = fsmn_layers
# self.linear_dim = linear_dim
# self.proj_dim = proj_dim
# self.output_affine_dim = output_affine_dim
# self.output_dim = output_dim
#
# self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
# self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
# self.relu = RectifiedLinear(linear_dim, linear_dim)
# self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
# range(fsmn_layers)])
# self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
# self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
# self.softmax = nn.Softmax(dim=-1)
self.in_linear1 = model.in_linear1
self.in_linear2 = model.in_linear2
self.relu = model.relu
# self.fsmn = model.fsmn
self.out_linear1 = model.out_linear1
self.out_linear2 = model.out_linear2
self.softmax = model.softmax
self.fsmn = model.fsmn
for i, d in enumerate(model.fsmn):
if isinstance(d, BasicBlock):
self.fsmn[i] = BasicBlock_export(d)
def fuse_modules(self):
pass
def forward(
self,
input: torch.Tensor,
*args,
):
"""
Args:
input (torch.Tensor): Input tensor (B, T, D)
in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
{'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
"""
x = self.in_linear1(input)
x = self.in_linear2(x)
x = self.relu(x)
# x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn
out_caches = list()
for i, d in enumerate(self.fsmn):
in_cache = args[i]
x, out_cache = d(x, in_cache)
out_caches.append(out_cache)
x = self.out_linear1(x)
x = self.out_linear2(x)
x = self.softmax(x)
return x, out_caches
'''
one deep fsmn layer
dimproj: projection dimension, input and output dimension of memory blocks
dimlinear: dimension of mapping layer
lorder: left order
rorder: right order
lstride: left stride
rstride: right stride
'''
class DFSMN(nn.Module):
def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
super(DFSMN, self).__init__()
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.expand = AffineTransform(dimproj, dimlinear)
self.shrink = LinearTransform(dimlinear, dimproj)
self.conv_left = nn.Conv2d(
dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
if rorder > 0:
self.conv_right = nn.Conv2d(
dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
else:
self.conv_right = None
def forward(self, input):
f1 = F.relu(self.expand(input))
p1 = self.shrink(f1)
x = torch.unsqueeze(p1, 1)
x_per = x.permute(0, 3, 2, 1)
y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
if self.conv_right is not None:
y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = y_right[:, :, self.rstride:, :]
out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
else:
out = x_per + self.conv_left(y_left)
out1 = out.permute(0, 3, 2, 1)
output = input + out1.squeeze(1)
return output
'''
build stacked dfsmn layers
'''
def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
repeats = [
nn.Sequential(
DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
for i in range(fsmn_layers)
]
return nn.Sequential(*repeats)
if __name__ == '__main__':
    # Smoke test: the FSMN class in this file is an export wrapper that expects an
    # already-built model, so this standalone demo uses the original implementation.
    from funasr_local.models.encoder.fsmn_encoder import FSMN as FSMNModel
    fsmn = FSMNModel(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
    print(fsmn)
    num_params = sum(p.numel() for p in fsmn.parameters())
    print('the number of model params: {}'.format(num_params))
    x = torch.zeros(128, 200, 400)  # batch-size * time * dim
    y, _ = fsmn(x)  # batch-size * time * dim
    print('input shape: {}'.format(x.shape))
    print('output shape: {}'.format(y.shape))
    print(fsmn.to_kaldi_net())

View File

@@ -0,0 +1,213 @@
import torch
import torch.nn as nn
from funasr_local.export.utils.torch_function import MakePadMask
from funasr_local.export.utils.torch_function import sequence_mask
from funasr_local.modules.attention import MultiHeadedAttentionSANM
from funasr_local.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export
from funasr_local.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
from funasr_local.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr_local.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
class SANMEncoder(nn.Module):
def __init__(
self,
model,
max_seq_len=512,
feats_dim=560,
model_name='encoder',
onnx: bool = True,
):
super().__init__()
self.embed = model.embed
self.model = model
self.feats_dim = feats_dim
self._output_size = model._output_size
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
if hasattr(model, 'encoders0'):
for i, d in enumerate(self.model.encoders0):
if isinstance(d.self_attn, MultiHeadedAttentionSANM):
d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
if isinstance(d.feed_forward, PositionwiseFeedForward):
d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
self.model.encoders0[i] = EncoderLayerSANM_export(d)
for i, d in enumerate(self.model.encoders):
if isinstance(d.self_attn, MultiHeadedAttentionSANM):
d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
if isinstance(d.feed_forward, PositionwiseFeedForward):
d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
self.model.encoders[i] = EncoderLayerSANM_export(d)
self.model_name = model_name
self.num_heads = model.encoders[0].self_attn.h
self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
def prepare_mask(self, mask):
mask_3d_btd = mask[:, :, None]
if len(mask.shape) == 2:
mask_4d_bhlt = 1 - mask[:, None, None, :]
elif len(mask.shape) == 3:
mask_4d_bhlt = 1 - mask[:, None, :]
mask_4d_bhlt = mask_4d_bhlt * -10000.0
return mask_3d_btd, mask_4d_bhlt
def forward(self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
):
speech = speech * self._output_size ** 0.5
mask = self.make_pad_mask(speech_lengths)
mask = self.prepare_mask(mask)
if self.embed is None:
xs_pad = speech
else:
xs_pad = self.embed(speech)
encoder_outs = self.model.encoders0(xs_pad, mask)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
encoder_outs = self.model.encoders(xs_pad, mask)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
xs_pad = self.model.after_norm(xs_pad)
return xs_pad, speech_lengths
def get_output_size(self):
return self.model.encoders[0].size
def get_dummy_inputs(self):
feats = torch.randn(1, 100, self.feats_dim)
return (feats)
def get_input_names(self):
return ['feats']
def get_output_names(self):
return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
def get_dynamic_axes(self):
return {
'feats': {
1: 'feats_length'
},
'encoder_out': {
1: 'enc_out_length'
},
'predictor_weight':{
1: 'pre_out_length'
}
}
class SANMVadEncoder(nn.Module):
def __init__(
self,
model,
max_seq_len=512,
feats_dim=560,
model_name='encoder',
onnx: bool = True,
):
super().__init__()
self.embed = model.embed
self.model = model
self.feats_dim = feats_dim
self._output_size = model._output_size
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
if hasattr(model, 'encoders0'):
for i, d in enumerate(self.model.encoders0):
if isinstance(d.self_attn, MultiHeadedAttentionSANM):
d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
if isinstance(d.feed_forward, PositionwiseFeedForward):
d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
self.model.encoders0[i] = EncoderLayerSANM_export(d)
for i, d in enumerate(self.model.encoders):
if isinstance(d.self_attn, MultiHeadedAttentionSANM):
d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
if isinstance(d.feed_forward, PositionwiseFeedForward):
d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
self.model.encoders[i] = EncoderLayerSANM_export(d)
self.model_name = model_name
self.num_heads = model.encoders[0].self_attn.h
self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
def prepare_mask(self, mask, sub_masks):
mask_3d_btd = mask[:, :, None]
mask_4d_bhlt = (1 - sub_masks) * -10000.0
return mask_3d_btd, mask_4d_bhlt
def forward(self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
vad_masks: torch.Tensor,
sub_masks: torch.Tensor,
):
speech = speech * self._output_size ** 0.5
mask = self.make_pad_mask(speech_lengths)
vad_masks = self.prepare_mask(mask, vad_masks)
mask = self.prepare_mask(mask, sub_masks)
if self.embed is None:
xs_pad = speech
else:
xs_pad = self.embed(speech)
encoder_outs = self.model.encoders0(xs_pad, mask)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
# encoder_outs = self.model.encoders(xs_pad, mask)
for layer_idx, encoder_layer in enumerate(self.model.encoders):
if layer_idx == len(self.model.encoders) - 1:
mask = vad_masks
encoder_outs = encoder_layer(xs_pad, mask)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
xs_pad = self.model.after_norm(xs_pad)
return xs_pad, speech_lengths
def get_output_size(self):
return self.model.encoders[0].size
# def get_dummy_inputs(self):
# feats = torch.randn(1, 100, self.feats_dim)
# return (feats)
#
# def get_input_names(self):
# return ['feats']
#
# def get_output_names(self):
# return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
#
# def get_dynamic_axes(self):
# return {
# 'feats': {
# 1: 'feats_length'
# },
# 'encoder_out': {
# 1: 'enc_out_length'
# },
# 'predictor_weight': {
# 1: 'pre_out_length'
# }
#
# }

View File

@@ -0,0 +1,71 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
from torch import nn
class DecoderLayerSANM(nn.Module):
def __init__(
self,
model
):
super().__init__()
self.self_attn = model.self_attn
self.src_attn = model.src_attn
self.feed_forward = model.feed_forward
self.norm1 = model.norm1
self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
self.size = model.size
def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
residual = tgt
tgt = self.norm1(tgt)
tgt = self.feed_forward(tgt)
x = tgt
if self.self_attn is not None:
tgt = self.norm2(tgt)
x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
x = residual + x
if self.src_attn is not None:
residual = x
x = self.norm3(x)
x = residual + self.src_attn(x, memory, memory_mask)
return x, tgt_mask, memory, memory_mask, cache
class DecoderLayer(nn.Module):
def __init__(self, model):
super().__init__()
self.self_attn = model.self_attn
self.src_attn = model.src_attn
self.feed_forward = model.feed_forward
self.norm1 = model.norm1
self.norm2 = model.norm2
self.norm3 = model.norm3
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
residual = tgt
tgt = self.norm1(tgt)
tgt_q = tgt
tgt_q_mask = tgt_mask
x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)
residual = x
x = self.norm2(x)
x = residual + self.src_attn(x, memory, memory, memory_mask)
residual = x
x = self.norm3(x)
x = residual + self.feed_forward(x)
return x, tgt_mask, memory, memory_mask

View File

@@ -0,0 +1,91 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
from torch import nn
class EncoderLayerSANM(nn.Module):
def __init__(
self,
model,
):
"""Construct an EncoderLayer object."""
super().__init__()
self.self_attn = model.self_attn
self.feed_forward = model.feed_forward
self.norm1 = model.norm1
self.norm2 = model.norm2
self.in_size = model.in_size
self.size = model.size
def forward(self, x, mask):
residual = x
x = self.norm1(x)
x = self.self_attn(x, mask)
if self.in_size == self.size:
x = x + residual
residual = x
x = self.norm2(x)
x = self.feed_forward(x)
x = x + residual
return x, mask
class EncoderLayerConformer(nn.Module):
def __init__(
self,
model,
):
"""Construct an EncoderLayer object."""
super().__init__()
self.self_attn = model.self_attn
self.feed_forward = model.feed_forward
self.feed_forward_macaron = model.feed_forward_macaron
self.conv_module = model.conv_module
self.norm_ff = model.norm_ff
self.norm_mha = model.norm_mha
self.norm_ff_macaron = model.norm_ff_macaron
self.norm_conv = model.norm_conv
self.norm_final = model.norm_final
self.size = model.size
def forward(self, x, mask):
if isinstance(x, tuple):
x, pos_emb = x[0], x[1]
else:
x, pos_emb = x, None
if self.feed_forward_macaron is not None:
residual = x
x = self.norm_ff_macaron(x)
x = residual + self.feed_forward_macaron(x) * 0.5
residual = x
x = self.norm_mha(x)
x_q = x
if pos_emb is not None:
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
else:
x_att = self.self_attn(x_q, x, x, mask)
x = residual + x_att
if self.conv_module is not None:
residual = x
x = self.norm_conv(x)
x = residual + self.conv_module(x)
residual = x
x = self.norm_ff(x)
x = residual + self.feed_forward(x) * 0.5
x = self.norm_final(x)
if pos_emb is not None:
return (x, pos_emb), mask
return x, mask

View File

@@ -0,0 +1,31 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
class PositionwiseFeedForward(nn.Module):
def __init__(self, model):
super().__init__()
self.w_1 = model.w_1
self.w_2 = model.w_2
self.activation = model.activation
def forward(self, x):
x = self.activation(self.w_1(x))
x = self.w_2(x)
return x
class PositionwiseFeedForwardDecoderSANM(nn.Module):
def __init__(self, model):
super().__init__()
self.w_1 = model.w_1
self.w_2 = model.w_2
self.activation = model.activation
self.norm = model.norm
def forward(self, x):
x = self.activation(self.w_1(x))
x = self.w_2(self.norm(x))
return x

View File

@@ -0,0 +1,243 @@
import os
import math
import torch
import torch.nn as nn
class MultiHeadedAttentionSANM(nn.Module):
def __init__(self, model):
super().__init__()
self.d_k = model.d_k
self.h = model.h
self.linear_out = model.linear_out
self.linear_q_k_v = model.linear_q_k_v
self.fsmn_block = model.fsmn_block
self.pad_fn = model.pad_fn
self.attn = None
self.all_head_size = self.h * self.d_k
def forward(self, x, mask):
mask_3d_btd, mask_4d_bhlt = mask
q_h, k_h, v_h, v = self.forward_qkv(x)
fsmn_memory = self.forward_fsmn(v, mask_3d_btd)
q_h = q_h * self.d_k**(-0.5)
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
att_outs = self.forward_attention(v_h, scores, mask_4d_bhlt)
return att_outs + fsmn_memory
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.h, self.d_k)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward_qkv(self, x):
q_k_v = self.linear_q_k_v(x)
q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
q_h = self.transpose_for_scores(q)
k_h = self.transpose_for_scores(k)
v_h = self.transpose_for_scores(v)
return q_h, k_h, v_h, v
def forward_fsmn(self, inputs, mask):
# b, t, d = inputs.size()
# mask = torch.reshape(mask, (b, -1, 1))
inputs = inputs * mask
x = inputs.transpose(1, 2)
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x = x + inputs
x = x * mask
return x
def forward_attention(self, value, scores, mask):
scores = scores + mask
self.attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return self.linear_out(context_layer) # (batch, time1, d_model)
def preprocess_for_attn(x, mask, cache, pad_fn):
x = x * mask
x = x.transpose(1, 2)
if cache is None:
x = pad_fn(x)
else:
x = torch.cat((cache[:, :, 1:], x), dim=2)
cache = x
return x, cache
torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]])
if torch_version >= (1, 8):
import torch.fx
torch.fx.wrap('preprocess_for_attn')
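# (Editor's note, added for clarity: torch.fx.wrap registers preprocess_for_attn as a
#  leaf function, so torch.fx symbolic tracing records a single call to it instead of
#  tracing through its cache-dependent `if`, which would otherwise break graph capture.)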
class MultiHeadedAttentionSANMDecoder(nn.Module):
def __init__(self, model):
super().__init__()
self.fsmn_block = model.fsmn_block
self.pad_fn = model.pad_fn
self.kernel_size = model.kernel_size
self.attn = None
def forward(self, inputs, mask, cache=None):
x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x = x + inputs
x = x * mask
return x, cache
class MultiHeadedAttentionCrossAtt(nn.Module):
def __init__(self, model):
super().__init__()
self.d_k = model.d_k
self.h = model.h
self.linear_q = model.linear_q
self.linear_k_v = model.linear_k_v
self.linear_out = model.linear_out
self.attn = None
self.all_head_size = self.h * self.d_k
def forward(self, x, memory, memory_mask):
q, k, v = self.forward_qkv(x, memory)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, memory_mask)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.h, self.d_k)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward_qkv(self, x, memory):
q = self.linear_q(x)
k_v = self.linear_k_v(memory)
k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
q = self.transpose_for_scores(q)
k = self.transpose_for_scores(k)
v = self.transpose_for_scores(v)
return q, k, v
def forward_attention(self, value, scores, mask):
scores = scores + mask
self.attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return self.linear_out(context_layer) # (batch, time1, d_model)
class OnnxMultiHeadedAttention(nn.Module):
def __init__(self, model):
super().__init__()
self.d_k = model.d_k
self.h = model.h
self.linear_q = model.linear_q
self.linear_k = model.linear_k
self.linear_v = model.linear_v
self.linear_out = model.linear_out
self.attn = None
self.all_head_size = self.h * self.d_k
def forward(self, query, key, value, mask):
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.h, self.d_k)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward_qkv(self, query, key, value):
q = self.linear_q(query)
k = self.linear_k(key)
v = self.linear_v(value)
q = self.transpose_for_scores(q)
k = self.transpose_for_scores(k)
v = self.transpose_for_scores(v)
return q, k, v
def forward_attention(self, value, scores, mask):
scores = scores + mask
self.attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return self.linear_out(context_layer) # (batch, time1, d_model)
class OnnxRelPosMultiHeadedAttention(OnnxMultiHeadedAttention):
def __init__(self, model):
super().__init__(model)
self.linear_pos = model.linear_pos
self.pos_bias_u = model.pos_bias_u
self.pos_bias_v = model.pos_bias_v
def forward(self, query, key, value, pos_emb, mask):
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
p = self.transpose_for_scores(self.linear_pos(pos_emb)) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, time1)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k
) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask)
def rel_shift(self, x):
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)[
:, :, :, : x.size(-1) // 2 + 1
] # only keep the positions from 0 to time2
return x
def forward_attention(self, value, scores, mask):
scores = scores + mask
self.attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return self.linear_out(context_layer) # (batch, time1, d_model)

View File

@@ -0,0 +1,288 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
from torch import nn
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
if maxlen is None:
maxlen = lengths.max()
row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
matrix = torch.unsqueeze(lengths, dim=-1)
mask = row_vector < matrix
mask = mask.detach()
return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
def sequence_mask_scripts(lengths, maxlen:int):
row_vector = torch.arange(0, maxlen, 1).type(lengths.dtype).to(lengths.device)
matrix = torch.unsqueeze(lengths, dim=-1)
mask = row_vector < matrix
return mask.type(torch.float32).to(lengths.device)
class CifPredictorV2(nn.Module):
def __init__(self, model):
super().__init__()
self.pad = model.pad
self.cif_conv1d = model.cif_conv1d
self.cif_output = model.cif_output
self.threshold = model.threshold
self.smooth_factor = model.smooth_factor
self.noise_threshold = model.noise_threshold
self.tail_threshold = model.tail_threshold
def forward(self, hidden: torch.Tensor,
mask: torch.Tensor,
):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
output = torch.relu(self.cif_conv1d(queries))
output = output.transpose(1, 2)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
mask = mask.transpose(-1, -2).float()
alphas = alphas * mask
alphas = alphas.squeeze(-1)
token_num = alphas.sum(-1)
mask = mask.squeeze(-1)
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
return acoustic_embeds, token_num, alphas, cif_peak
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
alphas = torch.cat([alphas, zeros_t], dim=1)
alphas = torch.add(alphas, tail_threshold)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
token_num = alphas.sum(dim=-1)
token_num_floor = torch.floor(token_num)
return hidden, alphas, token_num_floor
# @torch.jit.script
# def cif(hidden, alphas, threshold: float):
# batch_size, len_time, hidden_size = hidden.size()
# threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
#
# # loop varss
# integrate = torch.zeros([batch_size], device=hidden.device)
# frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
# # intermediate vars along time
# list_fires = []
# list_frames = []
#
# for t in range(len_time):
# alpha = alphas[:, t]
# distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
#
# integrate += alpha
# list_fires.append(integrate)
#
# fire_place = integrate >= threshold
# integrate = torch.where(fire_place,
# integrate - torch.ones([batch_size], device=hidden.device),
# integrate)
# cur = torch.where(fire_place,
# distribution_completion,
# alpha)
# remainds = alpha - cur
#
# frame += cur[:, None] * hidden[:, t, :]
# list_frames.append(frame)
# frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
# remainds[:, None] * hidden[:, t, :],
# frame)
#
# fires = torch.stack(list_fires, 1)
# frames = torch.stack(list_frames, 1)
# list_ls = []
# len_labels = torch.floor(alphas.sum(-1)).int()
# max_label_len = len_labels.max()
# for b in range(batch_size):
# fire = fires[b, :]
# l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
# pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device)
# list_ls.append(torch.cat([l, pad_l], 0))
# return torch.stack(list_ls, 0), fires
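# (Editor's summary, added for clarity: cif() implements Continuous Integrate-and-Fire.
#  The per-frame weights `alphas` are accumulated over time; whenever the running sum
#  crosses `threshold`, the weighted sum of hidden frames gathered so far is emitted as
#  one acoustic embedding ("fire") and the integrator restarts from the remainder.)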
@torch.jit.script
def cif(hidden, alphas, threshold: float):
batch_size, len_time, hidden_size = hidden.size()
threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
# loop vars
integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
# intermediate vars along time
list_fires = []
list_frames = []
for t in range(len_time):
alpha = alphas[:, t]
distribution_completion = torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
integrate += alpha
list_fires.append(integrate)
fire_place = integrate >= threshold
integrate = torch.where(fire_place,
integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
integrate)
cur = torch.where(fire_place,
distribution_completion,
alpha)
remainds = alpha - cur
frame += cur[:, None] * hidden[:, t, :]
list_frames.append(frame)
frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
remainds[:, None] * hidden[:, t, :],
frame)
fires = torch.stack(list_fires, 1)
frames = torch.stack(list_frames, 1)
fire_idxs = fires >= threshold
frame_fires = torch.zeros_like(hidden)
max_label_len = frames[0, fire_idxs[0]].size(0)
for b in range(batch_size):
frame_fire = frames[b, fire_idxs[b]]
frame_len = frame_fire.size(0)
frame_fires[b, :frame_len, :] = frame_fire
if frame_len >= max_label_len:
max_label_len = frame_len
frame_fires = frame_fires[:, :max_label_len, :]
return frame_fires, fires
class CifPredictorV3(nn.Module):
def __init__(self, model):
super().__init__()
self.pad = model.pad
self.cif_conv1d = model.cif_conv1d
self.cif_output = model.cif_output
self.threshold = model.threshold
self.smooth_factor = model.smooth_factor
self.noise_threshold = model.noise_threshold
self.tail_threshold = model.tail_threshold
self.upsample_times = model.upsample_times
self.upsample_cnn = model.upsample_cnn
self.blstm = model.blstm
self.cif_output2 = model.cif_output2
self.smooth_factor2 = model.smooth_factor2
self.noise_threshold2 = model.noise_threshold2
def forward(self, hidden: torch.Tensor,
mask: torch.Tensor,
):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
output = torch.relu(self.cif_conv1d(queries))
output = output.transpose(1, 2)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
mask = mask.transpose(-1, -2).float()
alphas = alphas * mask
alphas = alphas.squeeze(-1)
token_num = alphas.sum(-1)
mask = mask.squeeze(-1)
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
return acoustic_embeds, token_num, alphas, cif_peak
def get_upsample_timestmap(self, hidden, mask=None, token_num=None):
h = hidden
b = hidden.shape[0]
context = h.transpose(1, 2)
# generate alphas2
_output = context
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1, 2)
output2, (_, _) = self.blstm(output2)
alphas2 = torch.sigmoid(self.cif_output2(output2))
alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
mask = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
mask = mask.unsqueeze(-1)
alphas2 = alphas2 * mask
alphas2 = alphas2.squeeze(-1)
_token_num = alphas2.sum(-1)
alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
# upsampled alphas and cif_peak
us_alphas = alphas2
us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
return us_alphas, us_cif_peak
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
alphas = torch.cat([alphas, zeros_t], dim=1)
alphas = torch.add(alphas, tail_threshold)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
token_num = alphas.sum(dim=-1)
token_num_floor = torch.floor(token_num)
return hidden, alphas, token_num_floor
@torch.jit.script
def cif_wo_hidden(alphas, threshold: float):
batch_size, len_time = alphas.size()
# loop vars
integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=alphas.device)
# intermediate vars along time
list_fires = []
for t in range(len_time):
alpha = alphas[:, t]
integrate += alpha
list_fires.append(integrate)
fire_place = integrate >= threshold
integrate = torch.where(fire_place,
integrate - torch.ones([batch_size], device=alphas.device),
integrate)
fires = torch.stack(list_fires, 1)
return fires
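# --- Editor's note: minimal, hypothetical smoke test (not in the original file). ---
# It only exercises sequence_mask and cif_wo_hidden on dummy lengths/weights to show
# the expected tensor shapes; the numbers below are assumptions.
if __name__ == '__main__':
    lengths = torch.tensor([3, 5])
    mask = sequence_mask(lengths, maxlen=5)    # (2, 5) float mask, 1.0 marks valid frames
    alphas = torch.rand(2, 5) * mask           # per-frame CIF weights, padded frames zeroed
    fires = cif_wo_hidden(alphas, 1.0 - 1e-4)  # running integration values, shape (2, 5)
    print(mask.shape, fires.shape)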

View File

View File

@@ -0,0 +1,20 @@
import onnxruntime
import numpy as np
if __name__ == '__main__':
onnx_path = "/mnt/workspace/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.onnx"
sess = onnxruntime.InferenceSession(onnx_path)
input_name = [nd.name for nd in sess.get_inputs()]
output_name = [nd.name for nd in sess.get_outputs()]
def _get_feed_dict(feats_length):
return {'speech': np.zeros((1, feats_length, 560), dtype=np.float32), 'speech_lengths': np.array([feats_length,], dtype=np.int32)}
def _run(feed_dict):
output = sess.run(output_name, input_feed=feed_dict)
for name, value in zip(output_name, output):
print('{}: {}'.format(name, value.shape))
_run(_get_feed_dict(100))
_run(_get_feed_dict(200))

View File

@@ -0,0 +1,18 @@
import onnxruntime
import numpy as np
if __name__ == '__main__':
onnx_path = "../damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.onnx"
sess = onnxruntime.InferenceSession(onnx_path)
input_name = [nd.name for nd in sess.get_inputs()]
output_name = [nd.name for nd in sess.get_outputs()]
def _get_feed_dict(text_length):
return {'inputs': np.ones((1, text_length), dtype=np.int64), 'text_lengths': np.array([text_length,], dtype=np.int32)}
def _run(feed_dict):
output = sess.run(output_name, input_feed=feed_dict)
for name, value in zip(output_name, output):
print('{}: {}'.format(name, value))
_run(_get_feed_dict(10))

View File

@@ -0,0 +1,22 @@
import onnxruntime
import numpy as np
if __name__ == '__main__':
onnx_path = "./export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/model.onnx"
sess = onnxruntime.InferenceSession(onnx_path)
input_name = [nd.name for nd in sess.get_inputs()]
output_name = [nd.name for nd in sess.get_outputs()]
def _get_feed_dict(text_length):
return {'inputs': np.ones((1, text_length), dtype=np.int64),
'text_lengths': np.array([text_length,], dtype=np.int32),
'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
}
def _run(feed_dict):
output = sess.run(output_name, input_feed=feed_dict)
for name, value in zip(output_name, output):
print('{}: {}'.format(name, value))
_run(_get_feed_dict(10))

View File

@@ -0,0 +1,26 @@
import onnxruntime
import numpy as np
if __name__ == '__main__':
onnx_path = "/mnt/workspace/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx"
sess = onnxruntime.InferenceSession(onnx_path)
input_name = [nd.name for nd in sess.get_inputs()]
output_name = [nd.name for nd in sess.get_outputs()]
def _get_feed_dict(feats_length):
return {'speech': np.random.rand(1, feats_length, 400).astype(np.float32),
'in_cache0': np.random.rand(1, 128, 19, 1).astype(np.float32),
'in_cache1': np.random.rand(1, 128, 19, 1).astype(np.float32),
'in_cache2': np.random.rand(1, 128, 19, 1).astype(np.float32),
'in_cache3': np.random.rand(1, 128, 19, 1).astype(np.float32),
}
def _run(feed_dict):
output = sess.run(output_name, input_feed=feed_dict)
for name, value in zip(output_name, output):
print('{}: {}'.format(name, value.shape))
_run(_get_feed_dict(100))
_run(_get_feed_dict(200))

View File

@@ -0,0 +1,17 @@
import torch
import numpy as np
if __name__ == '__main__':
onnx_path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.torchscripts"
loaded = torch.jit.load(onnx_path)
x = torch.rand([2, 21, 560])
x_len = torch.IntTensor([6, 21])
res = loaded(x, x_len)
print(res[0].size(), res[1])
x = torch.rand([5, 50, 560])
x_len = torch.IntTensor([6, 21, 10, 30, 50])
res = loaded(x, x_len)
print(res[0].size(), res[1])

View File

View File

@@ -0,0 +1,80 @@
from typing import Optional
import torch
import torch.nn as nn
import numpy as np
class MakePadMask(nn.Module):
def __init__(self, max_seq_len=512, flip=True):
super().__init__()
if flip:
self.mask_pad = torch.Tensor(1 - np.tri(max_seq_len)).type(torch.bool)
else:
self.mask_pad = torch.Tensor(np.tri(max_seq_len)).type(torch.bool)
def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
"""Make mask tensor containing indices of padded part.
This implementation creates the same mask tensor as the original make_pad_mask,
but can be converted into ONNX format.
The dimension of xs should be 2 or 3.
"""
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
if xs is not None and len(xs.shape) == 3:
if length_dim == 1:
lengths = lengths.unsqueeze(1).expand(
*xs.transpose(1, 2).shape[:2])
else:
lengths = lengths.unsqueeze(1).expand(*xs.shape[:2])
if maxlen is not None:
m = maxlen
elif xs is not None:
m = xs.shape[-1]
else:
m = torch.max(lengths)
mask = self.mask_pad[lengths - 1][..., :m].type(torch.float32)
if length_dim == 1:
return mask.transpose(1, 2)
else:
return mask
class sequence_mask(nn.Module):
def __init__(self, max_seq_len=512, flip=True):
super().__init__()
def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
if max_seq_len is None:
max_seq_len = lengths.max()
row_vector = torch.arange(0, max_seq_len, 1).to(lengths.device)
matrix = torch.unsqueeze(lengths, dim=-1)
mask = row_vector < matrix
return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
def normalize(input: torch.Tensor, p: float = 2.0, dim: int = 1, out: Optional[torch.Tensor] = None) -> torch.Tensor:
if out is None:
denom = input.norm(p, dim, keepdim=True).expand_as(input)
return input / denom
else:
denom = input.norm(p, dim, keepdim=True).expand_as(input)
return torch.div(input, denom, out=out)
def subsequent_mask(size: torch.Tensor):
return torch.ones(size, size).tril()
def MakePadMask_test():
feats_length = torch.tensor([10]).type(torch.long)
mask_fn = MakePadMask()
mask = mask_fn(feats_length)
print(mask)
if __name__ == '__main__':
MakePadMask_test()

View File

View File

@@ -0,0 +1,78 @@
from pathlib import Path
from typing import Union
import warnings
from typeguard import check_argument_types
from typeguard import check_return_type
class DatadirWriter:
"""Writer class to create kaldi like data directory.
Examples:
>>> with DatadirWriter("output") as writer:
... # output/sub.txt is created here
... subwriter = writer["sub.txt"]
... # Write "uttidA some/where/a.wav"
... subwriter["uttidA"] = "some/where/a.wav"
... subwriter["uttidB"] = "some/where/b.wav"
"""
def __init__(self, p: Union[Path, str]):
assert check_argument_types()
self.path = Path(p)
self.chilidren = {}
self.fd = None
self.has_children = False
self.keys = set()
def __enter__(self):
return self
def __getitem__(self, key: str) -> "DatadirWriter":
assert check_argument_types()
if self.fd is not None:
raise RuntimeError("This writer points out a file")
if key not in self.chilidren:
w = DatadirWriter((self.path / key))
self.chilidren[key] = w
self.has_children = True
retval = self.chilidren[key]
assert check_return_type(retval)
return retval
def __setitem__(self, key: str, value: str):
assert check_argument_types()
if self.has_children:
raise RuntimeError("This writer points out a directory")
if key in self.keys:
warnings.warn(f"Duplicated: {key}")
if self.fd is None:
self.path.parent.mkdir(parents=True, exist_ok=True)
self.fd = self.path.open("w", encoding="utf-8")
self.keys.add(key)
self.fd.write(f"{key} {value}\n")
self.fd.flush()
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def close(self):
if self.has_children:
prev_child = None
for child in self.chilidren.values():
child.close()
if prev_child is not None and prev_child.keys != child.keys:
warnings.warn(
f"Ids are mismatching between "
f"{prev_child.path} and {child.path}"
)
prev_child = child
elif self.fd is not None:
self.fd.close()

View File

@@ -0,0 +1,97 @@
import collections.abc
from pathlib import Path
from typing import Union
import numpy as np
from typeguard import check_argument_types
from funasr_local.fileio.read_text import read_2column_text
class NpyScpWriter:
"""Writer class for a scp file of numpy file.
Examples:
key1 /some/path/a.npy
key2 /some/path/b.npy
key3 /some/path/c.npy
key4 /some/path/d.npy
...
>>> writer = NpyScpWriter('./data/', './data/feat.scp')
>>> writer['aa'] = numpy_array
>>> writer['bb'] = numpy_array
"""
def __init__(self, outdir: Union[Path, str], scpfile: Union[Path, str]):
assert check_argument_types()
self.dir = Path(outdir)
self.dir.mkdir(parents=True, exist_ok=True)
scpfile = Path(scpfile)
scpfile.parent.mkdir(parents=True, exist_ok=True)
self.fscp = scpfile.open("w", encoding="utf-8")
self.data = {}
def get_path(self, key):
return self.data[key]
def __setitem__(self, key, value):
assert isinstance(value, np.ndarray), type(value)
p = self.dir / f"{key}.npy"
p.parent.mkdir(parents=True, exist_ok=True)
np.save(str(p), value)
self.fscp.write(f"{key} {p}\n")
# Store the file path
self.data[key] = str(p)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def close(self):
self.fscp.close()
class NpyScpReader(collections.abc.Mapping):
"""Reader class for a scp file of numpy file.
Examples:
key1 /some/path/a.npy
key2 /some/path/b.npy
key3 /some/path/c.npy
key4 /some/path/d.npy
...
>>> reader = NpyScpReader('npy.scp')
>>> array = reader['key1']
"""
def __init__(self, fname: Union[Path, str]):
assert check_argument_types()
self.fname = Path(fname)
self.data = read_2column_text(fname)
def get_path(self, key):
return self.data[key]
def __getitem__(self, key) -> np.ndarray:
p = self.data[key]
return np.load(p)
def __contains__(self, item):
return item in self.data
def __len__(self):
return len(self.data)
def __iter__(self):
return iter(self.data)
def keys(self):
return self.data.keys()
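# --- Editor's note: hypothetical round-trip example (not in the original file). ---
# Writes two dummy arrays with NpyScpWriter and reads them back with NpyScpReader via a
# temporary directory; the keys, shapes, and paths are assumptions for illustration.
if __name__ == '__main__':
    import tempfile
    with tempfile.TemporaryDirectory() as d:
        with NpyScpWriter(f"{d}/data", f"{d}/feat.scp") as writer:
            writer["utt1"] = np.zeros((10, 80), dtype=np.float32)
            writer["utt2"] = np.ones((5, 80), dtype=np.float32)
        reader = NpyScpReader(f"{d}/feat.scp")
        for key in reader:
            print(key, reader[key].shape)  # utt1 (10, 80) / utt2 (5, 80)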

View File

@@ -0,0 +1,86 @@
import collections
from pathlib import Path
from typing import Union
import numpy as np
from typeguard import check_argument_types
from funasr_local.fileio.read_text import load_num_sequence_text
class FloatRandomGenerateDataset(collections.abc.Mapping):
"""Generate float array from shape.txt.
Examples:
shape.txt
uttA 123,83
uttB 34,83
>>> dataset = FloatRandomGenerateDataset("shape.txt")
>>> array = dataset["uttA"]
>>> assert array.shape == (123, 83)
>>> array = dataset["uttB"]
>>> assert array.shape == (34, 83)
"""
def __init__(
self,
shape_file: Union[Path, str],
dtype: Union[str, np.dtype] = "float32",
loader_type: str = "csv_int",
):
assert check_argument_types()
shape_file = Path(shape_file)
self.utt2shape = load_num_sequence_text(shape_file, loader_type)
self.dtype = np.dtype(dtype)
def __iter__(self):
return iter(self.utt2shape)
def __len__(self):
return len(self.utt2shape)
def __getitem__(self, item) -> np.ndarray:
shape = self.utt2shape[item]
return np.random.randn(*shape).astype(self.dtype)
class IntRandomGenerateDataset(collections.abc.Mapping):
"""Generate float array from shape.txt
Examples:
shape.txt
uttA 123,83
uttB 34,83
>>> dataset = IntRandomGenerateDataset("shape.txt", low=0, high=10)
>>> array = dataset["uttA"]
>>> assert array.shape == (123, 83)
>>> array = dataset["uttB"]
>>> assert array.shape == (34, 83)
"""
def __init__(
self,
shape_file: Union[Path, str],
low: int,
high: int = None,
dtype: Union[str, np.dtype] = "int64",
loader_type: str = "csv_int",
):
assert check_argument_types()
shape_file = Path(shape_file)
self.utt2shape = load_num_sequence_text(shape_file, loader_type)
self.dtype = np.dtype(dtype)
self.low = low
self.high = high
def __iter__(self):
return iter(self.utt2shape)
def __len__(self):
return len(self.utt2shape)
def __getitem__(self, item) -> np.ndarray:
shape = self.utt2shape[item]
return np.random.randint(self.low, self.high, size=shape, dtype=self.dtype)

View File

@@ -0,0 +1,81 @@
import logging
from pathlib import Path
from typing import Dict
from typing import List
from typing import Union
from typeguard import check_argument_types
def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
"""Read a text file having 2 column as dict object.
Examples:
wav.scp:
key1 /some/path/a.wav
key2 /some/path/b.wav
>>> read_2column_text('wav.scp')
{'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
"""
assert check_argument_types()
data = {}
with Path(path).open("r", encoding="utf-8") as f:
for linenum, line in enumerate(f, 1):
sps = line.rstrip().split(maxsplit=1)
if len(sps) == 1:
k, v = sps[0], ""
else:
k, v = sps
if k in data:
raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
data[k] = v
return data
def load_num_sequence_text(
path: Union[Path, str], loader_type: str = "csv_int"
) -> Dict[str, List[Union[float, int]]]:
"""Read a text file indicating sequences of number
Examples:
key1 1 2 3
key2 34 5 6
>>> d = load_num_sequence_text('text')
>>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3]))
"""
assert check_argument_types()
if loader_type == "text_int":
delimiter = " "
dtype = int
elif loader_type == "text_float":
delimiter = " "
dtype = float
elif loader_type == "csv_int":
delimiter = ","
dtype = int
elif loader_type == "csv_float":
delimiter = ","
dtype = float
else:
raise ValueError(f"Not supported loader_type={loader_type}")
# path looks like:
# utta 1,0
# uttb 3,4,5
# -> return {'utta': np.ndarray([1, 0]),
# 'uttb': np.ndarray([3, 4, 5])}
d = read_2column_text(path)
# Using for-loop instead of dict-comprehension for debuggability
retval = {}
for k, v in d.items():
try:
retval[k] = [dtype(i) for i in v.split(delimiter)]
except TypeError:
logging.error(f'Error happened with path="{path}", id="{k}", value="{v}"')
raise
return retval

View File

@@ -0,0 +1,136 @@
import collections.abc
from pathlib import Path
from typing import Union
import numpy as np
import soundfile
import librosa
from typeguard import check_argument_types
from funasr_local.fileio.read_text import read_2column_text
class SoundScpReader(collections.abc.Mapping):
"""Reader class for 'wav.scp'.
Examples:
key1 /some/path/a.wav
key2 /some/path/b.wav
key3 /some/path/c.wav
key4 /some/path/d.wav
...
>>> reader = SoundScpReader('wav.scp')
>>> rate, array = reader['key1']
"""
def __init__(
self,
fname,
dtype=np.int16,
always_2d: bool = False,
normalize: bool = False,
dest_sample_rate: int = 16000,
):
assert check_argument_types()
self.fname = fname
self.dtype = dtype
self.always_2d = always_2d
self.normalize = normalize
self.data = read_2column_text(fname)
self.dest_sample_rate = dest_sample_rate
def __getitem__(self, key):
wav = self.data[key]
if self.normalize:
# librosa.load returns float data normalized to [-1, 1] when no dtype is given
array, rate = librosa.load(
wav, sr=self.dest_sample_rate, mono=not self.always_2d
)
else:
array, rate = librosa.load(
wav, sr=self.dest_sample_rate, mono=not self.always_2d, dtype=self.dtype
)
return rate, array
def get_path(self, key):
return self.data[key]
def __contains__(self, item):
return item in self.data
def __len__(self):
return len(self.data)
def __iter__(self):
return iter(self.data)
def keys(self):
return self.data.keys()
class SoundScpWriter:
"""Writer class for 'wav.scp'
Examples:
key1 /some/path/a.wav
key2 /some/path/b.wav
key3 /some/path/c.wav
key4 /some/path/d.wav
...
>>> writer = SoundScpWriter('./data/', './data/feat.scp')
>>> writer['aa'] = 16000, numpy_array
>>> writer['bb'] = 16000, numpy_array
"""
def __init__(
self,
outdir: Union[Path, str],
scpfile: Union[Path, str],
format="wav",
dtype=None,
):
assert check_argument_types()
self.dir = Path(outdir)
self.dir.mkdir(parents=True, exist_ok=True)
scpfile = Path(scpfile)
scpfile.parent.mkdir(parents=True, exist_ok=True)
self.fscp = scpfile.open("w", encoding="utf-8")
self.format = format
self.dtype = dtype
self.data = {}
def __setitem__(self, key: str, value):
rate, signal = value
assert isinstance(rate, int), type(rate)
assert isinstance(signal, np.ndarray), type(signal)
if signal.ndim not in (1, 2):
raise RuntimeError(f"Input signal must be 1 or 2 dimension: {signal.ndim}")
if signal.ndim == 1:
signal = signal[:, None]
wav = self.dir / f"{key}.{self.format}"
wav.parent.mkdir(parents=True, exist_ok=True)
soundfile.write(str(wav), signal, rate)
self.fscp.write(f"{key} {wav}\n")
# Store the file path
self.data[key] = str(wav)
def get_path(self, key):
return self.data[key]
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def close(self):
self.fscp.close()

View File

View File

@@ -0,0 +1,9 @@
from abc import ABC
from abc import abstractmethod
from typing import Iterator
class AbsIterFactory(ABC):
@abstractmethod
def build_iter(self, epoch: int, shuffle: bool = None) -> Iterator:
raise NotImplementedError

View File

@@ -0,0 +1,215 @@
import logging
from typing import Any
from typing import Dict
from typing import Iterator
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr_local.iterators.abs_iter_factory import AbsIterFactory
from funasr_local.iterators.sequence_iter_factory import SequenceIterFactory
from funasr_local.samplers.abs_sampler import AbsSampler
class ChunkIterFactory(AbsIterFactory):
"""Creates chunks from a sequence
Examples:
>>> batches = [["id1"], ["id2"], ...]
>>> batch_size = 128
>>> chunk_length = 1000
>>> iter_factory = ChunkIterFactory(dataset, batches, batch_size, chunk_length)
>>> it = iter_factory.build_iter(epoch)
>>> for ids, batch in it:
... ...
- The number of mini-batches varies from epoch to epoch and cannot be known
in advance, because the IterFactory is not given the length information.
- For that reason, "num_iters_per_epoch" cannot be implemented
for this iterator; "num_samples_per_epoch" is implemented instead.
"""
def __init__(
self,
dataset,
batch_size: int,
batches: Union[AbsSampler, Sequence[Sequence[Any]]],
chunk_length: Union[int, str],
chunk_shift_ratio: float = 0.5,
num_cache_chunks: int = 1024,
num_samples_per_epoch: int = None,
seed: int = 0,
shuffle: bool = False,
num_workers: int = 0,
collate_fn=None,
pin_memory: bool = False,
):
assert check_argument_types()
assert all(len(x) == 1 for x in batches), "batch-size must be 1"
self.per_sample_iter_factory = SequenceIterFactory(
dataset=dataset,
batches=batches,
num_iters_per_epoch=num_samples_per_epoch,
seed=seed,
shuffle=shuffle,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
)
self.num_cache_chunks = max(num_cache_chunks, batch_size)
if isinstance(chunk_length, str):
if len(chunk_length) == 0:
raise ValueError("e.g. 5,8 or 3-5: but got empty string")
self.chunk_lengths = []
for x in chunk_length.split(","):
try:
sps = list(map(int, x.split("-")))
except ValueError:
raise ValueError(f"e.g. 5,8 or 3-5: but got {chunk_length}")
if len(sps) > 2:
raise ValueError(f"e.g. 5,8 or 3-5: but got {chunk_length}")
elif len(sps) == 2:
# Append all numbers between the range into the candidates
self.chunk_lengths += list(range(sps[0], sps[1] + 1))
else:
self.chunk_lengths += [sps[0]]
else:
# Single candidates: Fixed chunk length
self.chunk_lengths = [chunk_length]
self.chunk_shift_ratio = chunk_shift_ratio
self.batch_size = batch_size
self.seed = seed
self.shuffle = shuffle
def build_iter(
self,
epoch: int,
shuffle: bool = None,
) -> Iterator[Tuple[List[str], Dict[str, torch.Tensor]]]:
per_sample_loader = self.per_sample_iter_factory.build_iter(epoch, shuffle)
if shuffle is None:
shuffle = self.shuffle
state = np.random.RandomState(epoch + self.seed)
# NOTE(kamo):
# This iterator supports multiple chunk lengths and
# keeps chunks for each length here until the specified number is collected
cache_chunks_dict = {}
cache_id_list_dict = {}
for ids, batch in per_sample_loader:
# Must be per-sample-loader
assert len(ids) == 1, f"Must be per-sample-loader: {len(ids)}"
assert all(len(x) == 1 for x in batch.values())
# Get keys of sequence data
sequence_keys = []
for key in batch:
if key + "_lengths" in batch:
sequence_keys.append(key)
# Remove lengths data and get the first sample
batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
id_ = ids[0]
for key in sequence_keys:
if len(batch[key]) != len(batch[sequence_keys[0]]):
raise RuntimeError(
f"All sequences must has same length: "
f"{len(batch[key])} != {len(batch[sequence_keys[0]])}"
)
L = len(batch[sequence_keys[0]])
# Select chunk length
chunk_lengths = [lg for lg in self.chunk_lengths if lg < L]
if len(chunk_lengths) == 0:
logging.warning(
f"The length of '{id_}' is {L}, but it is shorter than "
f"any candidates of chunk-length: {self.chunk_lengths}"
)
continue
W = int(state.choice(chunk_lengths, 1))
cache_id_list = cache_id_list_dict.setdefault(W, [])
cache_chunks = cache_chunks_dict.setdefault(W, {})
# Shift width to the next chunk
S = int(W * self.chunk_shift_ratio)
# Number of chunks
N = (L - W) // S + 1
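# (Editor's illustration with assumed values: L=1000 frames, W=400 and
#  chunk_shift_ratio=0.5 give S=200, so N = (1000 - 400) // 200 + 1 = 4 chunks.)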
if shuffle:
Z = state.randint(0, (L - W) % S + 1)
else:
Z = 0
# Split a sequence into chunks.
# Note that trailing frames that do not fill a complete chunk are discarded
for k, v in batch.items():
if k not in cache_chunks:
cache_chunks[k] = []
if k in sequence_keys:
# Shift chunks with overlapped length for data augmentation
cache_chunks[k] += [v[Z + i * S : Z + i * S + W] for i in range(N)]
else:
# If not sequence, use whole data instead of chunk
cache_chunks[k] += [v for _ in range(N)]
cache_id_list += [id_ for _ in range(N)]
if len(cache_id_list) > self.num_cache_chunks:
cache_id_list, cache_chunks = yield from self._generate_mini_batches(
cache_id_list,
cache_chunks,
shuffle,
state,
)
cache_id_list_dict[W] = cache_id_list
cache_chunks_dict[W] = cache_chunks
else:
for W in cache_id_list_dict:
cache_id_list = cache_id_list_dict.setdefault(W, [])
cache_chunks = cache_chunks_dict.setdefault(W, {})
yield from self._generate_mini_batches(
cache_id_list,
cache_chunks,
shuffle,
state,
)
def _generate_mini_batches(
self,
id_list: List[str],
batches: Dict[str, List[torch.Tensor]],
shuffle: bool,
state: np.random.RandomState,
):
if shuffle:
indices = np.arange(0, len(id_list))
state.shuffle(indices)
batches = {k: [v[i] for i in indices] for k, v in batches.items()}
id_list = [id_list[i] for i in indices]
bs = self.batch_size
while len(id_list) >= bs:
# Make mini-batch and yield
yield (
id_list[:bs],
{k: torch.stack(v[:bs], 0) for k, v in batches.items()},
)
id_list = id_list[bs:]
batches = {k: v[bs:] for k, v in batches.items()}
return id_list, batches

View File

@@ -0,0 +1,37 @@
import logging
from typing import Callable
from typing import Collection
from typing import Iterator
import numpy as np
from typeguard import check_argument_types
from funasr_local.iterators.abs_iter_factory import AbsIterFactory
class MultipleIterFactory(AbsIterFactory):
def __init__(
self,
build_funcs: Collection[Callable[[], AbsIterFactory]],
seed: int = 0,
shuffle: bool = False,
):
assert check_argument_types()
self.build_funcs = list(build_funcs)
self.seed = seed
self.shuffle = shuffle
def build_iter(self, epoch: int, shuffle: bool = None) -> Iterator:
if shuffle is None:
shuffle = self.shuffle
build_funcs = list(self.build_funcs)
if shuffle:
np.random.RandomState(epoch + self.seed).shuffle(build_funcs)
for i, build_func in enumerate(build_funcs):
logging.info(f"Building {i}th iter-factory...")
iter_factory = build_func()
assert isinstance(iter_factory, AbsIterFactory), type(iter_factory)
yield from iter_factory.build_iter(epoch, shuffle)

View File

@@ -0,0 +1,143 @@
from typing import Any
from typing import Sequence
from typing import Union
import numpy as np
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from funasr_local.iterators.abs_iter_factory import AbsIterFactory
from funasr_local.samplers.abs_sampler import AbsSampler
class RawSampler(AbsSampler):
def __init__(self, batches):
self.batches = batches
def __len__(self):
return len(self.batches)
def __iter__(self):
return iter(self.batches)
def generate(self, seed):
return list(self.batches)
class SequenceIterFactory(AbsIterFactory):
"""Build iterator for each epoch.
This class simply creates pytorch DataLoader except for the following points:
- The random seed is decided according to the number of epochs. This feature
guarantees reproducibility when resuming from middle of training process.
- Enable to restrict the number of samples for one epoch. This features
controls the interval number between training and evaluation.
"""
def __init__(
self,
dataset,
batches: Union[AbsSampler, Sequence[Sequence[Any]]],
num_iters_per_epoch: int = None,
seed: int = 0,
shuffle: bool = False,
num_workers: int = 0,
collate_fn=None,
pin_memory: bool = False,
):
assert check_argument_types()
if not isinstance(batches, AbsSampler):
self.sampler = RawSampler(batches)
else:
self.sampler = batches
self.dataset = dataset
self.num_iters_per_epoch = num_iters_per_epoch
self.shuffle = shuffle
self.seed = seed
self.num_workers = num_workers
self.collate_fn = collate_fn
# https://discuss.pytorch.org/t/what-is-the-disadvantage-of-using-pin-memory/1702
self.pin_memory = pin_memory
def build_iter(self, epoch: int, shuffle: bool = None) -> DataLoader:
if shuffle is None:
shuffle = self.shuffle
if self.num_iters_per_epoch is not None:
N = len(self.sampler)
# If corpus size is larger than the num_per_epoch
if self.num_iters_per_epoch < N:
N = len(self.sampler)
real_epoch, offset = divmod(self.num_iters_per_epoch * epoch, N)
if offset >= self.num_iters_per_epoch:
current_batches = self.sampler.generate(real_epoch + self.seed)
if shuffle:
np.random.RandomState(real_epoch + self.seed).shuffle(
current_batches
)
batches = current_batches[
offset - self.num_iters_per_epoch : offset
]
else:
prev_batches = self.sampler.generate(real_epoch - 1 + self.seed)
current_batches = self.sampler.generate(real_epoch + self.seed)
if shuffle:
np.random.RandomState(real_epoch - 1 + self.seed).shuffle(
prev_batches
)
np.random.RandomState(real_epoch + self.seed).shuffle(
current_batches
)
batches = (
prev_batches[offset - self.num_iters_per_epoch :]
+ current_batches[:offset]
)
# If corpus size is less than the num_per_epoch
else:
_epoch, _cursor = divmod(self.num_iters_per_epoch * (epoch - 1), N)
_remain = self.num_iters_per_epoch
batches = []
current_batches = self.sampler.generate(_epoch + self.seed)
if shuffle:
np.random.RandomState(_epoch + self.seed).shuffle(current_batches)
while _remain > 0:
_batches = current_batches[_cursor : _cursor + _remain]
batches += _batches
if _cursor + _remain >= N:
_epoch += 1
_cursor = 0
current_batches = self.sampler.generate(_epoch + self.seed)
if shuffle:
np.random.RandomState(_epoch + self.seed).shuffle(
current_batches
)
else:
_cursor = _cursor + _remain
_remain -= len(_batches)
assert len(batches) == self.num_iters_per_epoch
else:
batches = self.sampler.generate(epoch + self.seed)
if shuffle:
np.random.RandomState(epoch + self.seed).shuffle(batches)
# For backward compatibility for pytorch DataLoader
if self.collate_fn is not None:
kwargs = dict(collate_fn=self.collate_fn)
else:
kwargs = {}
return DataLoader(
dataset=self.dataset,
batch_sampler=batches,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
**kwargs,
)
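# --- Editor's note: minimal, hypothetical usage sketch (not in the original file). ---
# The toy dataset and batch list are assumptions; it only illustrates that the same
# epoch number reproduces the same shuffled batch order.
if __name__ == '__main__':
    class _ToyDataset:
        def __getitem__(self, uid):
            return uid
    toy_batches = [["a", "b"], ["c", "d"], ["e", "f"]]
    factory = SequenceIterFactory(
        _ToyDataset(), toy_batches, seed=0, shuffle=True,
        collate_fn=lambda items: items,
    )
    print(list(factory.build_iter(epoch=1)))
    print(list(factory.build_iter(epoch=1)))  # identical order to the first call
    print(list(factory.build_iter(epoch=2)))  # a different shuffling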

Some files were not shown because too many files have changed in this diff