add files

This commit is contained in:
烨玮
2025-02-20 12:17:03 +08:00
parent a21dd4555c
commit edd008441b
667 changed files with 473123 additions and 0 deletions

View File

@@ -0,0 +1,85 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import ssl
import nltk
# create the nltk_data directory if it does not exist
try:
nltk.data.find('.')
except LookupError:
dir_list = nltk.data.path
for dir_item in dir_list:
if not os.path.exists(dir_item):
os.mkdir(dir_item)
if os.path.exists(dir_item):
break
# download one package if nltk_data still does not exist
try:
nltk.data.find('.')
except: # noqa: *
try:
_create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
pass
else:
ssl._create_default_https_context = _create_unverified_https_context
nltk.download('cmudict', halt_on_error=False, raise_on_error=True)
# deploy taggers/averaged_perceptron_tagger
try:
nltk.data.find('taggers/averaged_perceptron_tagger')
except: # noqa: *
data_dir = nltk.data.find('.')
target_dir = os.path.join(data_dir, 'taggers')
if not os.path.exists(target_dir):
os.mkdir(target_dir)
src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages',
'averaged_perceptron_tagger.zip')
shutil.copyfile(src_file,
os.path.join(target_dir, 'averaged_perceptron_tagger.zip'))
shutil._unpack_zipfile(
os.path.join(target_dir, 'averaged_perceptron_tagger.zip'), target_dir)
# deploy corpora/cmudict
try:
nltk.data.find('corpora/cmudict')
except: # noqa: *
data_dir = nltk.data.find('.')
target_dir = os.path.join(data_dir, 'corpora')
if not os.path.exists(target_dir):
os.mkdir(target_dir)
src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages',
'cmudict.zip')
shutil.copyfile(src_file, os.path.join(target_dir, 'cmudict.zip'))
shutil._unpack_zipfile(os.path.join(target_dir, 'cmudict.zip'), target_dir)
try:
nltk.data.find('taggers/averaged_perceptron_tagger')
except: # noqa: *
try:
_create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
pass
else:
ssl._create_default_https_context = _create_unverified_https_context
nltk.download('averaged_perceptron_tagger',
halt_on_error=False,
raise_on_error=True)
try:
nltk.data.find('corpora/cmudict')
except: # noqa: *
try:
_create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
pass
else:
ssl._create_default_https_context = _create_unverified_https_context
nltk.download('cmudict', halt_on_error=False, raise_on_error=True)
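Illustrative usage (not part of the commit): once this module has been imported and the deployment above has succeeded, the standard NLTK accessors can read the installed data; a minimal sketch, assuming corpora/cmudict was deployed:
# illustrative only: cmudict is provided by the deployment logic above
import nltk
from nltk.corpus import cmudict
prons = cmudict.dict()
print(prons.get('hello'))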

View File

@@ -0,0 +1,355 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import struct
from typing import Any, Dict, List, Union
import torchaudio
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger
logger = get_logger()
green_color = '\033[1;32m'
red_color = '\033[0;31;40m'
yellow_color = '\033[0;33;40m'
end_color = '\033[0m'
global_asr_language = 'zh-cn'
SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
def get_version():
return float(pkg_resources.get_distribution('easyasr').version)
def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str):
r_audio_fs = None
if audio_format == 'wav' or audio_format == 'scp':
r_audio_fs = get_sr_from_wav(audio_in)
elif audio_format == 'pcm' and isinstance(audio_in, bytes):
r_audio_fs = get_sr_from_bytes(audio_in)
return r_audio_fs
def type_checking(audio_in: Union[str, bytes],
audio_fs: int = None,
recog_type: str = None,
audio_format: str = None):
r_recog_type = recog_type
r_audio_format = audio_format
r_wav_path = audio_in
if isinstance(audio_in, str):
assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist'
elif isinstance(audio_in, bytes):
assert len(audio_in) > 0, 'audio in is empty'
r_audio_format = 'pcm'
r_recog_type = 'wav'
if audio_in is None:
# for raw_inputs
r_recog_type = 'wav'
r_audio_format = 'pcm'
if r_recog_type is None and audio_in is not None:
# audio_in is wav, recog_type is wav_file
if os.path.isfile(audio_in):
audio_type = os.path.basename(audio_in).lower()
for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
r_recog_type = 'wav'
r_audio_format = 'wav'
if audio_type.rfind(".scp") >= 0:
r_recog_type = 'wav'
r_audio_format = 'scp'
if r_recog_type is None:
raise NotImplementedError(
f'Not supported audio type: {audio_type}')
# recog_type is datasets_file
elif os.path.isdir(audio_in):
dir_name = os.path.basename(audio_in)
if 'test' in dir_name:
r_recog_type = 'test'
elif 'dev' in dir_name:
r_recog_type = 'dev'
elif 'train' in dir_name:
r_recog_type = 'train'
if r_audio_format is None:
if find_file_by_ends(audio_in, '.ark'):
r_audio_format = 'kaldi_ark'
elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends(
audio_in, '.WAV'):
r_audio_format = 'wav'
elif find_file_by_ends(audio_in, '.records'):
r_audio_format = 'tfrecord'
if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
# datasets with kaldi_ark file
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
elif r_audio_format == 'tfrecord' and r_recog_type != 'wav':
# datasets with tensorflow records file
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
elif r_audio_format == 'wav' and r_recog_type != 'wav':
# datasets with waveform files
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))
return r_recog_type, r_audio_format, r_wav_path
def get_sr_from_bytes(wav: bytes):
sr = None
data = wav
if len(data) > 44:
try:
header_fields = {}
header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
header_fields['Format'] = str(data[8:12], 'UTF-8')
header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
if header_fields['ChunkID'] == 'RIFF' and header_fields[
'Format'] == 'WAVE' and header_fields[
'Subchunk1ID'] == 'fmt ':
header_fields['SampleRate'] = struct.unpack('<I',
data[24:28])[0]
sr = header_fields['SampleRate']
except Exception:
# no treatment
pass
else:
logger.warning('audio bytes length ' + str(len(data)) + ' is invalid.')
return sr
def get_sr_from_wav(fname: str):
fs = None
if os.path.isfile(fname):
audio_type = os.path.basename(fname).lower()
for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
if support_audio_type == "pcm":
fs = None
else:
audio, fs = torchaudio.load(fname)
break
if audio_type.rfind(".scp") >= 0:
with open(fname, encoding="utf-8") as f:
for line in f:
wav_path = line.split()[1]
fs = get_sr_from_wav(wav_path)
if fs is not None:
break
return fs
elif os.path.isdir(fname):
dir_files = os.listdir(fname)
for file in dir_files:
file_path = os.path.join(fname, file)
if os.path.isfile(file_path):
fs = get_sr_from_wav(file_path)
elif os.path.isdir(file_path):
fs = get_sr_from_wav(file_path)
if fs is not None:
break
return fs
def find_file_by_ends(dir_path: str, ends: str):
dir_files = os.listdir(dir_path)
for file in dir_files:
file_path = os.path.join(dir_path, file)
if os.path.isfile(file_path):
if ends == ".wav" or ends == ".WAV":
audio_type = os.path.basename(file_path).lower()
for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
return True
raise NotImplementedError(
f'Not supported audio type: {audio_type}')
elif file_path.endswith(ends):
return True
elif os.path.isdir(file_path):
if find_file_by_ends(file_path, ends):
return True
return False
def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]:
dir_files = os.listdir(dir_path)
for file in dir_files:
file_path = os.path.join(dir_path, file)
if os.path.isfile(file_path):
audio_type = os.path.basename(file_path).lower()
for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
wav_list.append(file_path)
elif os.path.isdir(file_path):
recursion_dir_all_wav(wav_list, file_path)
return wav_list
def compute_wer(hyp_list: List[Any],
ref_list: List[Any],
lang: str = None) -> Dict[str, Any]:
assert len(hyp_list) > 0, 'hyp list is empty'
assert len(ref_list) > 0, 'ref list is empty'
rst = {
'Wrd': 0,
'Corr': 0,
'Ins': 0,
'Del': 0,
'Sub': 0,
'Snt': 0,
'Err': 0.0,
'S.Err': 0.0,
'wrong_words': 0,
'wrong_sentences': 0
}
if lang is None:
lang = global_asr_language
for h_item in hyp_list:
for r_item in ref_list:
if h_item['key'] == r_item['key']:
out_item = compute_wer_by_line(h_item['value'],
r_item['value'],
lang)
rst['Wrd'] += out_item['nwords']
rst['Corr'] += out_item['cor']
rst['wrong_words'] += out_item['wrong']
rst['Ins'] += out_item['ins']
rst['Del'] += out_item['del']
rst['Sub'] += out_item['sub']
rst['Snt'] += 1
if out_item['wrong'] > 0:
rst['wrong_sentences'] += 1
print_wrong_sentence(key=h_item['key'],
hyp=h_item['value'],
ref=r_item['value'])
else:
print_correct_sentence(key=h_item['key'],
hyp=h_item['value'],
ref=r_item['value'])
break
if rst['Wrd'] > 0:
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
if rst['Snt'] > 0:
rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
return rst
def compute_wer_by_line(hyp: List[str],
ref: List[str],
lang: str = 'zh-cn') -> Dict[str, Any]:
if lang != 'zh-cn':
hyp = hyp.split()
ref = ref.split()
hyp = list(map(lambda x: x.lower(), hyp))
ref = list(map(lambda x: x.lower(), ref))
len_hyp = len(hyp)
len_ref = len(ref)
cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
for i in range(len_hyp + 1):
cost_matrix[i][0] = i
for j in range(len_ref + 1):
cost_matrix[0][j] = j
for i in range(1, len_hyp + 1):
for j in range(1, len_ref + 1):
if hyp[i - 1] == ref[j - 1]:
cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
else:
substitution = cost_matrix[i - 1][j - 1] + 1
insertion = cost_matrix[i - 1][j] + 1
deletion = cost_matrix[i][j - 1] + 1
compare_val = [substitution, insertion, deletion]
min_val = min(compare_val)
operation_idx = compare_val.index(min_val) + 1
cost_matrix[i][j] = min_val
ops_matrix[i][j] = operation_idx
match_idx = []
i = len_hyp
j = len_ref
rst = {
'nwords': len_ref,
'cor': 0,
'wrong': 0,
'ins': 0,
'del': 0,
'sub': 0
}
while i >= 0 or j >= 0:
i_idx = max(0, i)
j_idx = max(0, j)
if ops_matrix[i_idx][j_idx] == 0: # correct
if i - 1 >= 0 and j - 1 >= 0:
match_idx.append((j - 1, i - 1))
rst['cor'] += 1
i -= 1
j -= 1
elif ops_matrix[i_idx][j_idx] == 2: # insert
i -= 1
rst['ins'] += 1
elif ops_matrix[i_idx][j_idx] == 3: # delete
j -= 1
rst['del'] += 1
elif ops_matrix[i_idx][j_idx] == 1: # substitute
i -= 1
j -= 1
rst['sub'] += 1
if i < 0 and j >= 0:
rst['del'] += 1
elif j < 0 and i >= 0:
rst['ins'] += 1
match_idx.reverse()
wrong_cnt = cost_matrix[len_hyp][len_ref]
rst['wrong'] = wrong_cnt
return rst
def print_wrong_sentence(key: str, hyp: str, ref: str):
space = len(key)
print(key + yellow_color + ' ref: ' + ref)
print(' ' * space + red_color + ' hyp: ' + hyp + end_color)
def print_correct_sentence(key: str, hyp: str, ref: str):
space = len(key)
print(key + yellow_color + ' ref: ' + ref)
print(' ' * space + green_color + ' hyp: ' + hyp + end_color)
def print_progress(percent):
if percent > 1:
percent = 1
res = int(50 * percent) * '#'
print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='')
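Illustrative usage of compute_wer (not part of the commit); the key/value structure mirrors the matching loop above, and the toy utterance is invented:
# illustrative only: one hypothesis/reference pair keyed by utterance id
hyp_list = [{'key': 'utt1', 'value': ['今', '天', '天', '气']}]
ref_list = [{'key': 'utt1', 'value': ['今', '天', '天', '汽']}]
rst = compute_wer(hyp_list, ref_list, lang='zh-cn')
print(rst['Err'], rst['S.Err'])  # character error rate and sentence error rate, in percent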

View File

@@ -0,0 +1,17 @@
import argparse
import dataclasses
from typeguard import check_type
def build_dataclass(dataclass, args: argparse.Namespace):
"""Helper function to build dataclass from 'args'."""
kwargs = {}
for field in dataclasses.fields(dataclass):
if not hasattr(args, field.name):
raise ValueError(
f"args doesn't have {field.name}. You need to set it to ArgumentsParser"
)
check_type(field.name, getattr(args, field.name), field.type)
kwargs[field.name] = getattr(args, field.name)
return dataclass(**kwargs)
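A hedged sketch of build_dataclass in use; the dataclass and the Namespace below are invented for illustration:
# illustrative only
import argparse
import dataclasses

@dataclasses.dataclass
class TrainConfig:
    lr: float
    epochs: int

args = argparse.Namespace(lr=1e-3, epochs=10)
cfg = build_dataclass(TrainConfig, args)  # -> TrainConfig(lr=0.001, epochs=10)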

View File

@@ -0,0 +1,65 @@
from collections.abc import Sequence
from distutils.util import strtobool as dist_strtobool
import sys
import numpy
def strtobool(x):
# distutils.util.strtobool returns an int, which is confusing, so cast it to bool
return bool(dist_strtobool(x))
def get_commandline_args():
extra_chars = [
" ",
";",
"&",
"(",
")",
"|",
"^",
"<",
">",
"?",
"*",
"[",
"]",
"$",
"`",
'"',
"\\",
"!",
"{",
"}",
]
# Escape the extra characters for shell
argv = [
arg.replace("'", "'\\''")
if all(char not in arg for char in extra_chars)
else "'" + arg.replace("'", "'\\''") + "'"
for arg in sys.argv
]
return sys.executable + " " + " ".join(argv)
def is_scipy_wav_style(value):
# Check whether value is a scipy.io.wavfile-style Tuple[int, numpy.ndarray]
return (
isinstance(value, Sequence)
and len(value) == 2
and isinstance(value[0], int)
and isinstance(value[1], numpy.ndarray)
)
def assert_scipy_wav_style(value):
assert is_scipy_wav_style(
value
), "Must be Tuple[int, numpy.ndarray], but got {}".format(
type(value)
if not isinstance(value, Sequence)
else "{}[{}]".format(type(value), ", ".join(str(type(v)) for v in value))
)
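Illustrative check (not part of the commit) of the scipy-style (rate, samples) convention used above; the array is synthetic:
# illustrative only: a (sample rate, samples) tuple passes the check
import numpy
rate, samples = 16000, numpy.zeros(16000, dtype=numpy.int16)
assert is_scipy_wav_style((rate, samples))
assert_scipy_wav_style((rate, samples))  # raises AssertionError for anything else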

View File

@@ -0,0 +1,59 @@
import numpy as np
from sklearn.metrics import roc_curve
import argparse
def _compute_eer(label, pred, positive_label=1):
"""
Python compute equal error rate (eer)
ONLY tested on binary classification
:param label: ground-truth label, should be a 1-d list or np.array, each element represents the ground-truth label of one sample
:param pred: model prediction, should be a 1-d list or np.array, each element represents the model prediction of one sample
:param positive_label: the class that is viewed as positive class when computing EER
:return: equal error rate (EER)
"""
# fpr, tpr, fnr and threshold are all 1-d np.arrays
fpr, tpr, threshold = roc_curve(label, pred, pos_label=positive_label)
fnr = 1 - tpr
# the threshold of fnr == fpr
eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
# theoretically the eer from fpr and the eer from fnr should be identical, but in practice they can differ slightly
eer_1 = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
eer_2 = fnr[np.nanargmin(np.absolute((fnr - fpr)))]
# return the mean of eer from fpr and from fnr
eer = (eer_1 + eer_2) / 2
return eer, eer_threshold
def compute_eer(trials_path, scores_path):
labels = []
for one_line in open(trials_path, "r"):
labels.append(one_line.strip().rsplit(" ", 1)[-1] == "target")
labels = np.array(labels, dtype=int)
scores = []
for one_line in open(scores_path, "r"):
scores.append(float(one_line.strip().rsplit(" ", 1)[-1]))
scores = np.array(scores, dtype=float)
eer, threshold = _compute_eer(labels, scores)
return eer, threshold
def main():
parser = argparse.ArgumentParser()
parser.add_argument("trials", help="trial list")
parser.add_argument("scores", help="score file, normalized to [0, 1]")
args = parser.parse_args()
eer, threshold = compute_eer(args.trials, args.scores)
print("EER is {:.4f} at threshold {:.4f}".format(eer * 100.0, threshold))
if __name__ == '__main__':
main()
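A hedged sketch of calling _compute_eer directly with toy labels and scores (invented values):
# illustrative only: binary ground-truth labels and scores in [0, 1]
import numpy as np
labels = np.array([1, 1, 1, 0, 0, 0])
scores = np.array([0.9, 0.8, 0.4, 0.6, 0.2, 0.1])
eer, thr = _compute_eer(labels, scores)
print('EER={:.4f} at threshold {:.4f}'.format(eer, thr))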

View File

@@ -0,0 +1,159 @@
#!/usr/bin/env python3
# Copyright 2018 David Snyder
# Apache 2.0
# This script computes the minimum detection cost function, which is a common
# error metric used in speaker recognition. Compared to equal error-rate,
# which assigns equal weight to false negatives and false positives, this
# error-rate is usually used to assess performance in settings where achieving
# a low false positive rate is more important than achieving a low false
# negative rate. See the NIST 2016 Speaker Recognition Evaluation Plan at
# https://www.nist.gov/sites/default/files/documents/2016/10/07/sre16_eval_plan_v1.3.pdf
# for more details about the metric.
from __future__ import print_function
from operator import itemgetter
import sys, argparse, os
def GetArgs():
parser = argparse.ArgumentParser(description="Compute the minimum "
"detection cost function along with the threshold at which it occurs. "
"Usage: sid/compute_min_dcf.py [options...] <scores-file> "
"<trials-file> "
"E.g., sid/compute_min_dcf.py --p-target 0.01 --c-miss 1 --c-fa 1 "
"exp/scores/trials data/test/trials",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--p-target', type=float, dest="p_target",
default=0.01,
help='The prior probability of the target speaker in a trial.')
parser.add_argument('--c-miss', type=float, dest="c_miss", default=1,
help='Cost of a missed detection. This is usually not changed.')
parser.add_argument('--c-fa', type=float, dest="c_fa", default=1,
help='Cost of a spurious detection. This is usually not changed.')
parser.add_argument("scores_filename",
help="Input scores file, with columns of the form "
"<utt1> <utt2> <score>")
parser.add_argument("trials_filename",
help="Input trials file, with columns of the form "
"<utt1> <utt2> <target/nontarget>")
sys.stderr.write(' '.join(sys.argv) + "\n")
args = parser.parse_args()
args = CheckArgs(args)
return args
def CheckArgs(args):
if args.c_fa <= 0:
raise Exception("--c-fa must be greater than 0")
if args.c_miss <= 0:
raise Exception("--c-miss must be greater than 0")
if args.p_target <= 0 or args.p_target >= 1:
raise Exception("--p-target must be greater than 0 and less than 1")
return args
# Creates a list of false-negative rates, a list of false-positive rates
# and a list of decision thresholds that give those error-rates.
def ComputeErrorRates(scores, labels):
# Sort the scores from smallest to largest, and also get the corresponding
# indexes of the sorted scores. We will treat the sorted scores as the
# thresholds at which the error-rates are evaluated.
sorted_indexes, thresholds = zip(*sorted(
[(index, threshold) for index, threshold in enumerate(scores)],
key=itemgetter(1)))
labels = [labels[i] for i in sorted_indexes]
fns = []
tns = []
# At the end of this loop, fns[i] is the number of errors made by
# incorrectly rejecting scores less than thresholds[i]. And, tns[i]
# is the total number of times that we have correctly rejected scores
# less than thresholds[i].
for i in range(0, len(labels)):
if i == 0:
fns.append(labels[i])
tns.append(1 - labels[i])
else:
fns.append(fns[i-1] + labels[i])
tns.append(tns[i-1] + 1 - labels[i])
positives = sum(labels)
negatives = len(labels) - positives
# Now divide the false negatives by the total number of
# positives to obtain the false negative rates across
# all thresholds
fnrs = [fn / float(positives) for fn in fns]
# Divide the true negatives by the total number of
# negatives to get the true negative rate. Subtract these
# quantities from 1 to get the false positive rates.
fprs = [1 - tn / float(negatives) for tn in tns]
return fnrs, fprs, thresholds
# Computes the minimum of the detection cost function. The comments refer to
# equations in Section 3 of the NIST 2016 Speaker Recognition Evaluation Plan.
def ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa):
min_c_det = float("inf")
min_c_det_threshold = thresholds[0]
for i in range(0, len(fnrs)):
# See Equation (2). it is a weighted sum of false negative
# and false positive errors.
c_det = c_miss * fnrs[i] * p_target + c_fa * fprs[i] * (1 - p_target)
if c_det < min_c_det:
min_c_det = c_det
min_c_det_threshold = thresholds[i]
# See Equations (3) and (4). Now we normalize the cost.
c_def = min(c_miss * p_target, c_fa * (1 - p_target))
min_dcf = min_c_det / c_def
return min_dcf, min_c_det_threshold
def compute_min_dcf(scores_filename, trials_filename, c_miss=1, c_fa=1, p_target=0.01):
scores_file = open(scores_filename, 'r').readlines()
trials_file = open(trials_filename, 'r').readlines()
c_miss = c_miss
c_fa = c_fa
p_target = p_target
scores = []
labels = []
trials = {}
for line in trials_file:
utt1, utt2, target = line.rstrip().split()
trial = utt1 + " " + utt2
trials[trial] = target
for line in scores_file:
utt1, utt2, score = line.rstrip().split()
trial = utt1 + " " + utt2
if trial in trials:
scores.append(float(score))
if trials[trial] == "target":
labels.append(1)
else:
labels.append(0)
else:
raise Exception("Missing entry for " + utt1 + " and " + utt2
+ " " + scores_filename)
fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target,
c_miss, c_fa)
return mindcf, threshold
def main():
args = GetArgs()
mindcf, threshold = compute_min_dcf(
args.scores_filename, args.trials_filename,
args.c_miss, args.c_fa, args.p_target
)
sys.stdout.write("minDCF is {0:.4f} at threshold {1:.4f} (p-target={2}, c-miss={3}, "
"c-fa={4})\n".format(mindcf, threshold, args.p_target, args.c_miss, args.c_fa))
if __name__ == "__main__":
main()
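Illustrative call into the functions above with toy trials (invented scores and labels), bypassing the file-based interface:
# illustrative only: 1 marks a target trial, 0 a nontarget trial
scores = [0.9, 0.8, 0.35, 0.1]
labels = [1, 1, 0, 0]
fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
min_dcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds,
                                   p_target=0.01, c_miss=1, c_fa=1)
print(min_dcf, threshold)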

View File

@@ -0,0 +1,157 @@
import os
import numpy as np
import sys
def compute_wer(ref_file,
hyp_file,
cer_detail_file):
rst = {
'Wrd': 0,
'Corr': 0,
'Ins': 0,
'Del': 0,
'Sub': 0,
'Snt': 0,
'Err': 0.0,
'S.Err': 0.0,
'wrong_words': 0,
'wrong_sentences': 0
}
hyp_dict = {}
ref_dict = {}
with open(hyp_file, 'r') as hyp_reader:
for line in hyp_reader:
key = line.strip().split()[0]
value = line.strip().split()[1:]
hyp_dict[key] = value
with open(ref_file, 'r') as ref_reader:
for line in ref_reader:
key = line.strip().split()[0]
value = line.strip().split()[1:]
ref_dict[key] = value
cer_detail_writer = open(cer_detail_file, 'w')
for hyp_key in hyp_dict:
if hyp_key in ref_dict:
out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
rst['Wrd'] += out_item['nwords']
rst['Corr'] += out_item['cor']
rst['wrong_words'] += out_item['wrong']
rst['Ins'] += out_item['ins']
rst['Del'] += out_item['del']
rst['Sub'] += out_item['sub']
rst['Snt'] += 1
if out_item['wrong'] > 0:
rst['wrong_sentences'] += 1
cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')
if rst['Wrd'] > 0:
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
if rst['Snt'] > 0:
rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
cer_detail_writer.write('\n')
cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + str(rst['Wrd']) +
", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n')
def compute_wer_by_line(hyp,
ref):
hyp = list(map(lambda x: x.lower(), hyp))
ref = list(map(lambda x: x.lower(), ref))
len_hyp = len(hyp)
len_ref = len(ref)
cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
for i in range(len_hyp + 1):
cost_matrix[i][0] = i
for j in range(len_ref + 1):
cost_matrix[0][j] = j
for i in range(1, len_hyp + 1):
for j in range(1, len_ref + 1):
if hyp[i - 1] == ref[j - 1]:
cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
else:
substitution = cost_matrix[i - 1][j - 1] + 1
insertion = cost_matrix[i - 1][j] + 1
deletion = cost_matrix[i][j - 1] + 1
compare_val = [substitution, insertion, deletion]
min_val = min(compare_val)
operation_idx = compare_val.index(min_val) + 1
cost_matrix[i][j] = min_val
ops_matrix[i][j] = operation_idx
match_idx = []
i = len_hyp
j = len_ref
rst = {
'nwords': len_ref,
'cor': 0,
'wrong': 0,
'ins': 0,
'del': 0,
'sub': 0
}
while i >= 0 or j >= 0:
i_idx = max(0, i)
j_idx = max(0, j)
if ops_matrix[i_idx][j_idx] == 0: # correct
if i - 1 >= 0 and j - 1 >= 0:
match_idx.append((j - 1, i - 1))
rst['cor'] += 1
i -= 1
j -= 1
elif ops_matrix[i_idx][j_idx] == 2: # insert
i -= 1
rst['ins'] += 1
elif ops_matrix[i_idx][j_idx] == 3: # delete
j -= 1
rst['del'] += 1
elif ops_matrix[i_idx][j_idx] == 1: # substitute
i -= 1
j -= 1
rst['sub'] += 1
if i < 0 and j >= 0:
rst['del'] += 1
elif j < 0 and i >= 0:
rst['ins'] += 1
match_idx.reverse()
wrong_cnt = cost_matrix[len_hyp][len_ref]
rst['wrong'] = wrong_cnt
return rst
def print_cer_detail(rst):
return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
+ ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
+ str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords'])
+ ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords']))
if __name__ == '__main__':
if len(sys.argv) != 4:
print("usage : python compute-wer.py test.ref test.hyp test.wer")
sys.exit(0)
ref_file = sys.argv[1]
hyp_file = sys.argv[2]
cer_detail_file = sys.argv[3]
compute_wer(ref_file, hyp_file, cer_detail_file)

View File

@@ -0,0 +1,47 @@
import argparse
from pathlib import Path
import yaml
class ArgumentParser(argparse.ArgumentParser):
"""Simple implementation of ArgumentParser supporting config file
This class is originated from https://github.com/bw2/ConfigArgParse,
but this class is lack of some features that it has.
- Not supporting multiple config files
- Automatically adding "--config" as an option.
- Not supporting any formats other than yaml
- Not checking argument type
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.add_argument("--config", help="Give config file in yaml format")
def parse_known_args(self, args=None, namespace=None):
# Once parsing for setting from "--config"
_args, _ = super().parse_known_args(args, namespace)
if _args.config is not None:
if not Path(_args.config).exists():
self.error(f"No such file: {_args.config}")
with open(_args.config, "r", encoding="utf-8") as f:
d = yaml.safe_load(f)
if not isinstance(d, dict):
self.error("Config file has non dict value: {_args.config}")
for key in d:
for action in self._actions:
if key == action.dest:
break
else:
self.error(f"unrecognized arguments: {key} (from {_args.config})")
# NOTE(kamo): Ignore "--config" from a config file
# NOTE(kamo): Unlike "configargparse", this module doesn't check type.
# i.e. We can set any type value regardless of argument type.
self.set_defaults(**d)
return super().parse_known_args(args, namespace)
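A hedged usage sketch; the option names and the conf.yaml file are made up for the example:
# illustrative only: conf.yaml contains e.g. "lr: 0.001" and "epochs: 10"
parser = ArgumentParser(description='demo')
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--epochs', type=int, default=1)
args = parser.parse_args(['--config', 'conf.yaml', '--epochs', '20'])
# values from conf.yaml override the parser defaults; explicit command-line flags still win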

View File

@@ -0,0 +1,57 @@
import inspect
class Invalid:
"""Marker object for not serializable-object"""
def get_default_kwargs(func):
"""Get the default values of the input function.
Examples:
>>> def func(a, b=3): pass
>>> get_default_kwargs(func)
{'b': 3}
"""
def yaml_serializable(value):
# isinstance(x, tuple) includes namedtuple, so type is used here
if type(value) is tuple:
return yaml_serializable(list(value))
elif isinstance(value, set):
return yaml_serializable(list(value))
elif isinstance(value, dict):
if not all(isinstance(k, str) for k in value):
return Invalid
retval = {}
for k, v in value.items():
v2 = yaml_serializable(v)
# Register only valid object
if v2 not in (Invalid, inspect.Parameter.empty):
retval[k] = v2
return retval
elif isinstance(value, list):
retval = []
for v in value:
v2 = yaml_serializable(v)
# If any elements in the list are invalid,
# the list also becomes invalid
if v2 is Invalid:
return Invalid
else:
retval.append(v2)
return retval
elif value in (inspect.Parameter.empty, None):
return value
elif isinstance(value, (float, int, complex, bool, str, bytes)):
return value
else:
return Invalid
# params: An ordered mapping of inspect.Parameter
params = inspect.signature(func).parameters
data = {p.name: p.default for p in params.values()}
# Remove not yaml-serializable object
data = yaml_serializable(data)
return data

View File

@@ -0,0 +1,192 @@
#!/usr/bin/env python3
"""Griffin-Lim related modules."""
# Copyright 2019 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import logging
from distutils.version import LooseVersion
from functools import partial
from typeguard import check_argument_types
from typing import Optional
import librosa
import numpy as np
import torch
EPS = 1e-10
def logmel2linear(
lmspc: np.ndarray,
fs: int,
n_fft: int,
n_mels: int,
fmin: int = None,
fmax: int = None,
) -> np.ndarray:
"""Convert log Mel filterbank to linear spectrogram.
Args:
lmspc: Log Mel filterbank (T, n_mels).
fs: Sampling frequency.
n_fft: The number of FFT points.
n_mels: The number of mel basis.
fmin: Minimum frequency to analyze.
fmax: Maximum frequency to analyze.
Returns:
Linear spectrogram (T, n_fft // 2 + 1).
"""
assert lmspc.shape[1] == n_mels
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
mspc = np.power(10.0, lmspc)
mel_basis = librosa.filters.mel(
sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
)
inv_mel_basis = np.linalg.pinv(mel_basis)
return np.maximum(EPS, np.dot(inv_mel_basis, mspc.T).T)
def griffin_lim(
spc: np.ndarray,
n_fft: int,
n_shift: int,
win_length: int = None,
window: Optional[str] = "hann",
n_iter: Optional[int] = 32,
) -> np.ndarray:
"""Convert linear spectrogram into waveform using Griffin-Lim.
Args:
spc: Linear spectrogram (T, n_fft // 2 + 1).
n_fft: The number of FFT points.
n_shift: Shift size in points.
win_length: Window length in points.
window: Window function type.
n_iter: The number of iterations.
Returns:
Reconstructed waveform (N,).
"""
# assert the size of input linear spectrogram
assert spc.shape[1] == n_fft // 2 + 1
if LooseVersion(librosa.__version__) >= LooseVersion("0.7.0"):
# use librosa's fast Griffin-Lim algorithm
spc = np.abs(spc.T)
y = librosa.griffinlim(
S=spc,
n_iter=n_iter,
hop_length=n_shift,
win_length=win_length,
window=window,
center=True if spc.shape[1] > 1 else False,
)
else:
# use slower version of Griffin-Lim algorithm
logging.warning(
"librosa version is old. Using the slow Griffin-Lim algorithm. "
"If you want to use the fast Griffin-Lim, please update librosa via "
"`source ./path.sh && pip install librosa==0.7.0`."
)
cspc = np.abs(spc).astype(np.complex128).T
angles = np.exp(2j * np.pi * np.random.rand(*cspc.shape))
y = librosa.istft(cspc * angles, n_shift, win_length, window=window)
for i in range(n_iter):
angles = np.exp(
1j
* np.angle(librosa.stft(y, n_fft, n_shift, win_length, window=window))
)
y = librosa.istft(cspc * angles, n_shift, win_length, window=window)
return y
# TODO(kan-bayashi): write as torch.nn.Module
class Spectrogram2Waveform(object):
"""Spectrogram to waveform conversion module."""
def __init__(
self,
n_fft: int,
n_shift: int,
fs: int = None,
n_mels: int = None,
win_length: int = None,
window: Optional[str] = "hann",
fmin: int = None,
fmax: int = None,
griffin_lim_iters: Optional[int] = 8,
):
"""Initialize module.
Args:
fs: Sampling frequency.
n_fft: The number of FFT points.
n_shift: Shift size in points.
n_mels: The number of mel basis.
win_length: Window length in points.
window: Window function type.
fmin: Minimum frequency to analyze.
fmax: Maximum frequency to analyze.
griffin_lim_iters: The number of iterations.
"""
assert check_argument_types()
self.fs = fs
self.logmel2linear = (
partial(
logmel2linear, fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
)
if n_mels is not None
else None
)
self.griffin_lim = partial(
griffin_lim,
n_fft=n_fft,
n_shift=n_shift,
win_length=win_length,
window=window,
n_iter=griffin_lim_iters,
)
self.params = dict(
n_fft=n_fft,
n_shift=n_shift,
win_length=win_length,
window=window,
n_iter=griffin_lim_iters,
)
if n_mels is not None:
self.params.update(fs=fs, n_mels=n_mels, fmin=fmin, fmax=fmax)
def __repr__(self):
retval = f"{self.__class__.__name__}("
for k, v in self.params.items():
retval += f"{k}={v}, "
retval += ")"
return retval
def __call__(self, spc: torch.Tensor) -> torch.Tensor:
"""Convert spectrogram to waveform.
Args:
spc: Log Mel filterbank (T_feats, n_mels)
or linear spectrogram (T_feats, n_fft // 2 + 1).
Returns:
Tensor: Reconstructed waveform (T_wav,).
"""
device = spc.device
dtype = spc.dtype
spc = spc.cpu().numpy()
if self.logmel2linear is not None:
spc = self.logmel2linear(spc)
wav = self.griffin_lim(spc)
return torch.tensor(wav).to(device=device, dtype=dtype)
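A minimal sketch of Spectrogram2Waveform (the feature settings are invented; the input tensor is random and only illustrates the expected shapes):
# illustrative only
import torch
spc2wav = Spectrogram2Waveform(
    n_fft=1024, n_shift=256, fs=16000, n_mels=80,
    win_length=1024, window="hann", griffin_lim_iters=32)
lmspc = torch.randn(100, 80)   # (T_feats, n_mels) log-Mel features
wav = spc2wav(lmspc)           # reconstructed waveform tensor, shape (T_wav,)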

View File

@@ -0,0 +1,103 @@
from __future__ import print_function
from multiprocessing import Pool
import argparse
from tqdm import tqdm
import math
class MultiProcessRunner:
def __init__(self, fn):
self.args = None
self.process = fn
def run(self):
parser = argparse.ArgumentParser("")
# Task-independent options
parser.add_argument("--nj", type=int, default=16)
parser.add_argument("--debug", action="store_true", default=False)
parser.add_argument("--no_pbar", action="store_true", default=False)
parser.add_argument("--verbose", action="store_ture", default=False)
task_list, args = self.prepare(parser)
result_list = self.pool_run(task_list, args)
self.post(result_list, args)
def prepare(self, parser):
raise NotImplementedError("Please implement the prepare function.")
def post(self, result_list, args):
raise NotImplementedError("Please implement the post function.")
def pool_run(self, tasks, args):
results = []
if args.debug:
one_result = self.process(tasks[0])
results.append(one_result)
else:
pool = Pool(args.nj)
for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar):
results.append(one_result)
pool.close()
return results
class MultiProcessRunnerV2:
def __init__(self, fn):
self.args = None
self.process = fn
def run(self):
parser = argparse.ArgumentParser("")
# Task-independent options
parser.add_argument("--nj", type=int, default=16)
parser.add_argument("--debug", action="store_true", default=False)
parser.add_argument("--no_pbar", action="store_true", default=False)
parser.add_argument("--verbose", action="store_true", default=False)
task_list, args = self.prepare(parser)
chunk_size = int(math.ceil(float(len(task_list)) / args.nj))
if args.verbose:
print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size))
subtask_list = [task_list[i*chunk_size: (i+1)*chunk_size] for i in range(args.nj)]
result_list = self.pool_run(subtask_list, args)
self.post(result_list, args)
def prepare(self, parser):
raise NotImplementedError("Please implement the prepare function.")
def post(self, result_list, args):
raise NotImplementedError("Please implement the post function.")
def pool_run(self, tasks, args):
results = []
if args.debug:
one_result = self.process(tasks[0])
results.append(one_result)
else:
pool = Pool(args.nj)
for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar):
results.append(one_result)
pool.close()
return results
class MultiProcessRunnerV3(MultiProcessRunnerV2):
def run(self):
parser = argparse.ArgumentParser("")
# Task-independent options
parser.add_argument("--nj", type=int, default=16)
parser.add_argument("--debug", action="store_true", default=False)
parser.add_argument("--no_pbar", action="store_true", default=False)
parser.add_argument("--verbose", action="store_true", default=False)
parser.add_argument("--sr", type=int, default=16000)
task_list, shared_param, args = self.prepare(parser)
chunk_size = int(math.ceil(float(len(task_list)) / args.nj))
if args.verbose:
print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size))
subtask_list = [(i, task_list[i * chunk_size: (i + 1) * chunk_size], shared_param, args)
for i in range(args.nj)]
result_list = self.pool_run(subtask_list, args)
self.post(result_list, args)
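A hedged sketch of subclassing one of the runners above; the task function and the prepare/post bodies are invented:
# illustrative only
def square(x):
    return x * x

class SquareRunner(MultiProcessRunner):
    def prepare(self, parser):
        args = parser.parse_args()
        return list(range(100)), args

    def post(self, result_list, args):
        print(sum(result_list))

if __name__ == '__main__':
    SquareRunner(square).run()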

View File

@@ -0,0 +1,48 @@
import io
from collections import OrderedDict
import numpy as np
def statistic_model_parameters(model, prefix=None):
var_dict = model.state_dict()
numel = 0
for i, key in enumerate(sorted(list([x for x in var_dict.keys() if "num_batches_tracked" not in x]))):
if prefix is None or key.startswith(prefix):
numel += var_dict[key].numel()
return numel
def int2vec(x, vec_dim=8, dtype=int):
b = ('{:0' + str(vec_dim) + 'b}').format(x)
# little-endian order: lower bit first
return (np.array(list(b)[::-1]) == '1').astype(dtype)
def seq2arr(seq, vec_dim=8):
return np.row_stack([int2vec(int(x), vec_dim) for x in seq])
def load_scp_as_dict(scp_path, value_type='str', kv_sep=" "):
with io.open(scp_path, 'r', encoding='utf-8') as f:
ret_dict = OrderedDict()
for one_line in f.readlines():
one_line = one_line.strip()
pos = one_line.find(kv_sep)
key, value = one_line[:pos], one_line[pos + 1:]
if value_type == 'list':
value = value.split(' ')
ret_dict[key] = value
return ret_dict
def load_scp_as_list(scp_path, value_type='str', kv_sep=" "):
with io.open(scp_path, 'r', encoding='utf8') as f:
ret_dict = []
for one_line in f.readlines():
one_line = one_line.strip()
pos = one_line.find(kv_sep)
key, value = one_line[:pos], one_line[pos + 1:]
if value_type == 'list':
value = value.split(' ')
ret_dict.append((key, value))
return ret_dict
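Illustrative usage (the file names and contents are assumed; Kaldi-style scp files hold one "key value" pair per line):
# illustrative only: wav.scp lines look like "utt1 /data/wav/utt1.wav"
wav_scp = load_scp_as_dict('wav.scp')                # OrderedDict {utt_id: path}
text = load_scp_as_list('text', value_type='list')   # [(utt_id, [token, ...]), ...]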

View File

@@ -0,0 +1,35 @@
class modelscope_args():
def __init__(self,
task: str = "",
model: str = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
data_path: str = None,
output_dir: str = None,
model_revision: str = None,
dataset_type: str = "small",
batch_bins: int = 2000,
max_epoch: int = None,
accum_grad: int = None,
keep_nbest_models: int = None,
optim: str = None,
lr: float = None,
scheduler: str = None,
scheduler_conf: dict = None,
specaug: str = None,
specaug_conf: dict = None,
):
self.task = task
self.model = model
self.data_path = data_path
self.output_dir = output_dir
self.model_revision = model_revision
self.dataset_type = dataset_type
self.batch_bins = batch_bins
self.max_epoch = max_epoch
self.accum_grad = accum_grad
self.keep_nbest_models = keep_nbest_models
self.optim = optim
self.lr = lr
self.scheduler = scheduler
self.scheduler_conf = scheduler_conf
self.specaug = specaug
self.specaug_conf = specaug_conf

View File

@@ -0,0 +1,106 @@
import argparse
import copy
import yaml
class NestedDictAction(argparse.Action):
"""Action class to append items to dict object.
Examples:
>>> parser = argparse.ArgumentParser()
>>> _ = parser.add_argument('--conf', action=NestedDictAction,
... default={'a': 4})
>>> parser.parse_args(['--conf', 'a=3', '--conf', 'c=4'])
Namespace(conf={'a': 3, 'c': 4})
>>> parser.parse_args(['--conf', 'c.d=4'])
Namespace(conf={'a': 4, 'c': {'d': 4}})
>>> parser.parse_args(['--conf', 'c.d=4', '--conf', 'c=2'])
Namespace(conf={'a': 4, 'c': 2})
>>> parser.parse_args(['--conf', '{d: 5, e: 9}'])
Namespace(conf={'d': 5, 'e': 9})
"""
_syntax = """Syntax:
{op} <key>=<yaml-string>
{op} <key>.<key2>=<yaml-string>
{op} <python-dict>
{op} <yaml-string>
e.g.
{op} a=4
{op} a.b={{c: true}}
{op} {{"c": True}}
{op} {{a: 34.5}}
"""
def __init__(
self,
option_strings,
dest,
nargs=None,
default=None,
choices=None,
required=False,
help=None,
metavar=None,
):
super().__init__(
option_strings=option_strings,
dest=dest,
nargs=nargs,
default=copy.deepcopy(default),
type=None,
choices=choices,
required=required,
help=help,
metavar=metavar,
)
def __call__(self, parser, namespace, values, option_strings=None):
# --{option} a.b=3 -> {'a': {'b': 3}}
if "=" in values:
indict = copy.deepcopy(getattr(namespace, self.dest, {}))
key, value = values.split("=", maxsplit=1)
if not value.strip() == "":
value = yaml.load(value, Loader=yaml.Loader)
if not isinstance(indict, dict):
indict = {}
keys = key.split(".")
d = indict
for idx, k in enumerate(keys):
if idx == len(keys) - 1:
d[k] = value
else:
if not isinstance(d.setdefault(k, {}), dict):
# Remove the existing value and recreates as empty dict
d[k] = {}
d = d[k]
# Update the value
setattr(namespace, self.dest, indict)
else:
try:
# First, try eval(), i.e. a Python-syntax dict,
# e.g. --{option} "{'a': 3}" -> {'a': 3}.
# This is a workaround for the internal behaviour of configargparse.
value = eval(values, {}, {})
if not isinstance(value, dict):
syntax = self._syntax.format(op=option_strings)
mes = f"must be interpreted as dict: but got {values}\n{syntax}"
raise argparse.ArgumentTypeError(self, mes)
except Exception:
# If eval() fails, fall back to yaml.load
value = yaml.load(values, Loader=yaml.Loader)
if not isinstance(value, dict):
syntax = self._syntax.format(op=option_strings)
mes = f"must be interpreted as dict: but got {values}\n{syntax}"
raise argparse.ArgumentError(self, mes)
d = getattr(namespace, self.dest, None)
if isinstance(d, dict):
d.update(value)
else:
# Remove existing params, and overwrite
setattr(namespace, self.dest, value)

View File

@@ -0,0 +1,245 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import string
import logging
from typing import Any, List, Union
def isChinese(ch: str):
if '\u4e00' <= ch <= '\u9fff' or '\u0030' <= ch <= '\u0039' or ch == '@':
return True
return False
def isAllChinese(word: Union[List[Any], str]):
word_lists = []
for i in word:
cur = i.replace(' ', '')
cur = cur.replace('</s>', '')
cur = cur.replace('<s>', '')
cur = cur.replace('<unk>', '')
cur = cur.replace('<OOV>', '')
word_lists.append(cur)
if len(word_lists) == 0:
return False
for ch in word_lists:
if isChinese(ch) is False:
return False
return True
def isAllAlpha(word: Union[List[Any], str]):
word_lists = []
for i in word:
cur = i.replace(' ', '')
cur = cur.replace('</s>', '')
cur = cur.replace('<s>', '')
cur = cur.replace('<unk>', '')
cur = cur.replace('<OOV>', '')
word_lists.append(cur)
if len(word_lists) == 0:
return False
for ch in word_lists:
if ch.isalpha() is False and ch != "'":
return False
elif ch.isalpha() is True and isChinese(ch) is True:
return False
return True
# def abbr_dispose(words: List[Any]) -> List[Any]:
def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
words_size = len(words)
word_lists = []
abbr_begin = []
abbr_end = []
last_num = -1
ts_lists = []
ts_nums = []
ts_index = 0
for num in range(words_size):
if num <= last_num:
continue
if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
if (num + 1 < words_size and words[num + 1] == ' '
and num + 2 < words_size and len(words[num + 2]) == 1
and words[num + 2].encode('utf-8').isalpha()):
# found the begin of abbr
abbr_begin.append(num)
num += 2
abbr_end.append(num)
# to find the end of abbr
while True:
num += 1
if num < words_size and words[num] == ' ':
num += 1
if (num < words_size and len(words[num]) == 1
and words[num].encode('utf-8').isalpha()):
abbr_end.pop()
abbr_end.append(num)
last_num = num
else:
break
else:
break
for num in range(words_size):
if words[num] == ' ':
ts_nums.append(ts_index)
else:
ts_nums.append(ts_index)
ts_index += 1
last_num = -1
for num in range(words_size):
if num <= last_num:
continue
if num in abbr_begin:
if time_stamp is not None:
begin = time_stamp[ts_nums[num]][0]
abbr_word = words[num].upper()
num += 1
while num < words_size:
if num in abbr_end:
abbr_word += words[num].upper()
last_num = num
break
else:
if words[num].encode('utf-8').isalpha():
abbr_word += words[num].upper()
num += 1
word_lists.append(abbr_word)
if time_stamp is not None:
end = time_stamp[ts_nums[num]][1]
ts_lists.append([begin, end])
else:
word_lists.append(words[num])
if time_stamp is not None and words[num] != ' ':
begin = time_stamp[ts_nums[num]][0]
end = time_stamp[ts_nums[num]][1]
ts_lists.append([begin, end])
begin = end
if time_stamp is not None:
return word_lists, ts_lists
else:
return word_lists
def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
middle_lists = []
word_lists = []
word_item = ''
ts_lists = []
# wash words lists
for i in words:
word = ''
if isinstance(i, str):
word = i
else:
word = i.decode('utf-8')
if word in ['<s>', '</s>', '<unk>', '<OOV>']:
continue
else:
middle_lists.append(word)
# all chinese characters
if isAllChinese(middle_lists):
for i, ch in enumerate(middle_lists):
word_lists.append(ch.replace(' ', ''))
if time_stamp is not None:
ts_lists = time_stamp
# all alpha characters
elif isAllAlpha(middle_lists):
ts_flag = True
for i, ch in enumerate(middle_lists):
if ts_flag and time_stamp is not None:
begin = time_stamp[i][0]
end = time_stamp[i][1]
word = ''
if '@@' in ch:
word = ch.replace('@@', '')
word_item += word
if time_stamp is not None:
ts_flag = False
end = time_stamp[i][1]
else:
word_item += ch
word_lists.append(word_item)
word_lists.append(' ')
word_item = ''
if time_stamp is not None:
ts_flag = True
end = time_stamp[i][1]
ts_lists.append([begin, end])
begin = end
# mix characters
else:
alpha_blank = False
ts_flag = True
begin = -1
end = -1
for i, ch in enumerate(middle_lists):
if ts_flag and time_stamp is not None:
begin = time_stamp[i][0]
end = time_stamp[i][1]
word = ''
if isAllChinese(ch):
if alpha_blank is True:
word_lists.pop()
word_lists.append(ch)
alpha_blank = False
if time_stamp is not None:
ts_flag = True
ts_lists.append([begin, end])
begin = end
elif '@@' in ch:
word = ch.replace('@@', '')
word_item += word
alpha_blank = False
if time_stamp is not None:
ts_flag = False
end = time_stamp[i][1]
elif isAllAlpha(ch):
word_item += ch
word_lists.append(word_item)
word_lists.append(' ')
word_item = ''
alpha_blank = True
if time_stamp is not None:
ts_flag = True
end = time_stamp[i][1]
ts_lists.append([begin, end])
begin = end
else:
word_lists.append(ch)
if time_stamp is not None:
word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
real_word_lists = []
for ch in word_lists:
if ch != ' ':
real_word_lists.append(ch)
sentence = ' '.join(real_word_lists).strip()
return sentence, ts_lists, real_word_lists
else:
word_lists = abbr_dispose(word_lists)
real_word_lists = []
for ch in word_lists:
if ch != ' ':
real_word_lists.append(ch)
sentence = ''.join(word_lists).strip()
return sentence, real_word_lists
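A hedged sketch of sentence_postprocess on English BPE output (toy tokens; '@@' marks a subword that continues into the next token):
# illustrative only
tokens = ['he@@', 'llo', 'wor@@', 'ld']
sentence, word_lists = sentence_postprocess(tokens)
# sentence == 'hello world', word_lists == ['hello', 'world']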

View File

@@ -0,0 +1,75 @@
import collections
import sys
from torch import multiprocessing
def get_size(obj, seen=None):
"""Recursively finds size of objects
Taken from https://github.com/bosswissam/pysize
"""
size = sys.getsizeof(obj)
if seen is None:
seen = set()
obj_id = id(obj)
if obj_id in seen:
return 0
# Important: mark as seen *before* entering recursion to gracefully handle
# self-referential objects
seen.add(obj_id)
if isinstance(obj, dict):
size += sum([get_size(v, seen) for v in obj.values()])
size += sum([get_size(k, seen) for k in obj.keys()])
elif hasattr(obj, "__dict__"):
size += get_size(obj.__dict__, seen)
elif isinstance(obj, (list, set, tuple)):
size += sum([get_size(i, seen) for i in obj])
return size
class SizedDict(collections.abc.MutableMapping):
def __init__(self, shared: bool = False, data: dict = None):
if data is None:
data = {}
if shared:
# NOTE(kamo): Don't set manager as a field because Manager, which includes
# weakref object, causes following error with method="spawn",
# "TypeError: can't pickle weakref objects"
self.cache = multiprocessing.Manager().dict(**data)
else:
self.manager = None
self.cache = dict(**data)
self.size = 0
def __setitem__(self, key, value):
if key in self.cache:
self.size -= get_size(self.cache[key])
else:
self.size += sys.getsizeof(key)
self.size += get_size(value)
self.cache[key] = value
def __getitem__(self, key):
return self.cache[key]
def __delitem__(self, key):
self.size -= get_size(self.cache[key])
self.size -= sys.getsizeof(key)
del self.cache[key]
def __iter__(self):
return iter(self.cache)
def __contains__(self, key):
return key in self.cache
def __len__(self):
return len(self.cache)

View File

@@ -0,0 +1,320 @@
from itertools import zip_longest
import torch
import copy
import codecs
import logging
import edit_distance
import argparse
import pdb
import numpy as np
from typing import Any, List, Tuple, Union
def ts_prediction_lfr6_standard(us_alphas,
us_peaks,
char_list,
vad_offset=0.0,
force_time_shift=-1.5,
sil_in_str=True
):
if not len(char_list):
return "", []
START_END_THRESHOLD = 5
MAX_TOKEN_DURATION = 12
TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
if len(us_alphas.shape) == 2:
_, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
else:
_, peaks = us_alphas, us_peaks
num_frames = peaks.shape[0]
if char_list[-1] == '</s>':
char_list = char_list[:-1]
timestamp_list = []
new_char_list = []
# for bicif model trained with large data, cif2 actually fires when a character starts
# so treat the frames between two peaks as the duration of the former token
fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
num_peak = len(fire_place)
assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
# begin silence
if fire_place[0] > START_END_THRESHOLD:
# char_list.insert(0, '<sil>')
timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
new_char_list.append('<sil>')
# tokens timestamp
for i in range(len(fire_place)-1):
new_char_list.append(char_list[i])
if MAX_TOKEN_DURATION < 0 or fire_place[i+1] - fire_place[i] <= MAX_TOKEN_DURATION:
timestamp_list.append([fire_place[i]*TIME_RATE, fire_place[i+1]*TIME_RATE])
else:
# cut the duration to token and sil of the 0-weight frames last long
_split = fire_place[i] + MAX_TOKEN_DURATION
timestamp_list.append([fire_place[i]*TIME_RATE, _split*TIME_RATE])
timestamp_list.append([_split*TIME_RATE, fire_place[i+1]*TIME_RATE])
new_char_list.append('<sil>')
# tail token and end silence
# new_char_list.append(char_list[-1])
if num_frames - fire_place[-1] > START_END_THRESHOLD:
_end = (num_frames + fire_place[-1]) * 0.5
# _end = fire_place[-1]
timestamp_list[-1][1] = _end*TIME_RATE
timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
new_char_list.append("<sil>")
else:
timestamp_list[-1][1] = num_frames*TIME_RATE
if vad_offset: # add offset time in model with vad
for i in range(len(timestamp_list)):
timestamp_list[i][0] = timestamp_list[i][0] + vad_offset / 1000.0
timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
res_txt = ""
for char, timestamp in zip(new_char_list, timestamp_list):
#if char != '<sil>':
if not sil_in_str and char == '<sil>': continue
res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
res = []
for char, timestamp in zip(new_char_list, timestamp_list):
if char != '<sil>':
res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
return res_txt, res
def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed):
res = []
if text_postprocessed is None:
return res
if time_stamp_postprocessed is None:
return res
if len(time_stamp_postprocessed) == 0:
return res
if len(text_postprocessed) == 0:
return res
if punc_id_list is None or len(punc_id_list) == 0:
res.append({
'text': text_postprocessed.split(),
"start": time_stamp_postprocessed[0][0],
"end": time_stamp_postprocessed[-1][1]
})
return res
if len(punc_id_list) != len(time_stamp_postprocessed):
print(" warning length mistach!!!!!!")
sentence_text = ''
sentence_start = time_stamp_postprocessed[0][0]
sentence_end = time_stamp_postprocessed[0][1]
texts = text_postprocessed.split()
punc_stamp_text_list = list(zip_longest(punc_id_list, time_stamp_postprocessed, texts, fillvalue=None))
for punc_stamp_text in punc_stamp_text_list:
punc_id, time_stamp, text = punc_stamp_text
sentence_text += text if text is not None else ''
punc_id = int(punc_id) if punc_id is not None else 1
sentence_end = time_stamp[1] if time_stamp is not None else sentence_end
if punc_id == 2:
sentence_text += ','
res.append({
'text': sentence_text,
"start": sentence_start,
"end": sentence_end
})
sentence_text = ''
sentence_start = sentence_end
elif punc_id == 3:
sentence_text += '.'
res.append({
'text': sentence_text,
"start": sentence_start,
"end": sentence_end
})
sentence_text = ''
sentence_start = sentence_end
elif punc_id == 4:
sentence_text += '?'
res.append({
'text': sentence_text,
"start": sentence_start,
"end": sentence_end
})
sentence_text = ''
sentence_start = sentence_end
return res
class AverageShiftCalculator():
def __init__(self):
logging.warning("Calculating average shift.")
def __call__(self, file1, file2):
uttid_list1, ts_dict1 = self.read_timestamps(file1)
uttid_list2, ts_dict2 = self.read_timestamps(file2)
uttid_intersection = self._intersection(uttid_list1, uttid_list2)
res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2)
logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8]))
logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))
def _intersection(self, list1, list2):
set1 = set(list1)
set2 = set(list2)
if set1 == set2:
logging.warning("Uttid same checked.")
return set1
itsc = list(set1 & set2)
logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc)))
return itsc
def read_timestamps(self, file):
# read timestamps file in standard format
uttid_list = []
ts_dict = {}
with codecs.open(file, 'r') as fin:
for line in fin.readlines():
text = ''
ts_list = []
line = line.rstrip()
uttid = line.split()[0]
uttid_list.append(uttid)
body = " ".join(line.split()[1:])
for pd in body.split(';'):
if not len(pd): continue
# pdb.set_trace()
char, start, end = pd.lstrip(" ").split(' ')
text += char + ','
ts_list.append((float(start), float(end)))
# ts_lists.append(ts_list)
ts_dict[uttid] = (text[:-1], ts_list)
logging.warning("File {} read done.".format(file))
return uttid_list, ts_dict
def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
shift_time = 0
for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2):
shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1])
num_tokens = len(filtered_timestamp_list1)
return shift_time, num_tokens
def as_cal(self, uttid_list, ts_dict1, ts_dict2):
# calculate average shift between timestamp1 and timestamp2
# when characters differ, use edit distance alignment
# and calculate the error between the same characters
self._accumlated_shift = 0
self._accumlated_tokens = 0
self.max_shift = 0
self.max_shift_uttid = None
for uttid in uttid_list:
(t1, ts1) = ts_dict1[uttid]
(t2, ts2) = ts_dict2[uttid]
_align, _align2, _align3 = [], [], []
fts1, fts2 = [], []
_t1, _t2 = [], []
sm = edit_distance.SequenceMatcher(t1.split(','), t2.split(','))
s = sm.get_opcodes()
for j in range(len(s)):
if s[j][0] == "replace" or s[j][0] == "insert":
_align.append(0)
if s[j][0] == "replace" or s[j][0] == "delete":
_align3.append(0)
elif s[j][0] == "equal":
_align.append(1)
_align3.append(1)
else:
continue
# use s to index t2
for a, ts, t in zip(_align, ts2, t2.split(',')):
if a:
fts2.append(ts)
_t2.append(t)
sm2 = edit_distance.SequenceMatcher(t2.split(','), t1.split(','))
s = sm2.get_opcodes()
for j in range(len(s)):
if s[j][0] == "replace" or s[j][0] == "insert":
_align2.append(0)
elif s[j][0] == "equal":
_align2.append(1)
else:
continue
# use s2 to index t1
for a, ts, t in zip(_align3, ts1, t1.split(',')):
if a:
fts1.append(ts)
_t1.append(t)
if len(fts1) == len(fts2):
shift_time, num_tokens = self._shift(fts1, fts2)
self._accumlated_shift += shift_time
self._accumlated_tokens += num_tokens
if shift_time/num_tokens > self.max_shift:
self.max_shift = shift_time/num_tokens
self.max_shift_uttid = uttid
else:
logging.warning("length mismatch")
return self._accumlated_shift / self._accumlated_tokens
def convert_external_alphas(alphas_file, text_file, output_file):
from funasr_local.models.predictor.cif import cif_wo_hidden
with open(alphas_file, 'r') as f1, open(text_file, 'r') as f2, open(output_file, 'w') as f3:
for line1, line2 in zip(f1.readlines(), f2.readlines()):
line1 = line1.rstrip()
line2 = line2.rstrip()
assert line1.split()[0] == line2.split()[0]
uttid = line1.split()[0]
alphas = [float(i) for i in line1.split()[1:]]
new_alphas = np.array(remove_chunk_padding(alphas))
new_alphas[-1] += 1e-4
text = line2.split()[1:]
if len(text) + 1 != int(new_alphas.sum()):
# force resize
new_alphas *= (len(text) + 1) / int(new_alphas.sum())
peaks = cif_wo_hidden(torch.Tensor(new_alphas).unsqueeze(0), 1.0-1e-4)
if " " in text:
text = text.split()
else:
text = [i for i in text]
res_str, _ = ts_prediction_lfr6_standard(new_alphas, peaks[0], text,
force_time_shift=-7.0,
sil_in_str=False)
f3.write("{} {}\n".format(uttid, res_str))
def remove_chunk_padding(alphas):
# remove the padding part in alphas if using chunk paraformer for GPU
START_ZERO = 45
MID_ZERO = 75
REAL_FRAMES = 360 # for chunk based encoder 10-120-10 and fsmn padding 5
alphas = alphas[START_ZERO:] # remove the padding at beginning
new_alphas = []
while True:
new_alphas = new_alphas + alphas[:REAL_FRAMES]
alphas = alphas[REAL_FRAMES+MID_ZERO:]
if len(alphas) < REAL_FRAMES: break
return new_alphas
SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']
def main(args):
if args.mode == 'cal_aas':
asc = AverageShiftCalculator()
asc(args.input, args.input2)
elif args.mode == 'read_ext_alphas':
convert_external_alphas(args.input, args.input2, args.output)
else:
logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='timestamp tools')
parser.add_argument('--mode',
default=None,
type=str,
choices=SUPPORTED_MODES,
help='timestamp related toolbox')
parser.add_argument('--input', default=None, type=str, help='input file path')
parser.add_argument('--output', default=None, type=str, help='output file name')
parser.add_argument('--input2', default=None, type=str, help='input2 file path')
parser.add_argument('--kaldi-ts-type',
default='v2',
type=str,
choices=['v0', 'v1', 'v2'],
help='kaldi timestamp to write')
args = parser.parse_args()
main(args)

funasr_local/utils/types.py
View File

@@ -0,0 +1,149 @@
from distutils.util import strtobool
from typing import Optional
from typing import Tuple
from typing import Union

import humanfriendly


def str2bool(value: str) -> bool:
    return bool(strtobool(value))


def remove_parenthesis(value: str):
    value = value.strip()
    if value.startswith("(") and value.endswith(")"):
        value = value[1:-1]
    elif value.startswith("[") and value.endswith("]"):
        value = value[1:-1]
    return value


def remove_quotes(value: str):
    value = value.strip()
    if value.startswith('"') and value.endswith('"'):
        value = value[1:-1]
    elif value.startswith("'") and value.endswith("'"):
        value = value[1:-1]
    return value


def int_or_none(value: str) -> Optional[int]:
    """int_or_none.

    Examples:
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=int_or_none)
        >>> parser.parse_args(['--foo', '456'])
        Namespace(foo=456)
        >>> parser.parse_args(['--foo', 'none'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'null'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'nil'])
        Namespace(foo=None)
    """
    if value.strip().lower() in ("none", "null", "nil"):
        return None
    return int(value)


def float_or_none(value: str) -> Optional[float]:
    """float_or_none.

    Examples:
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=float_or_none)
        >>> parser.parse_args(['--foo', '4.5'])
        Namespace(foo=4.5)
        >>> parser.parse_args(['--foo', 'none'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'null'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'nil'])
        Namespace(foo=None)
    """
    if value.strip().lower() in ("none", "null", "nil"):
        return None
    return float(value)


def humanfriendly_parse_size_or_none(value) -> Optional[float]:
    if value.strip().lower() in ("none", "null", "nil"):
        return None
    return humanfriendly.parse_size(value)


def str_or_int(value: str) -> Union[str, int]:
    try:
        return int(value)
    except ValueError:
        return value


def str_or_none(value: str) -> Optional[str]:
    """str_or_none.

    Examples:
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=str_or_none)
        >>> parser.parse_args(['--foo', 'aaa'])
        Namespace(foo='aaa')
        >>> parser.parse_args(['--foo', 'none'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'null'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'nil'])
        Namespace(foo=None)
    """
    if value.strip().lower() in ("none", "null", "nil"):
        return None
    return value


def str2pair_str(value: str) -> Tuple[str, str]:
    """str2pair_str.

    Examples:
        >>> import argparse
        >>> str2pair_str('abc,def ')
        ('abc', 'def')
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=str2pair_str)
        >>> parser.parse_args(['--foo', 'abc,def'])
        Namespace(foo=('abc', 'def'))
    """
    value = remove_parenthesis(value)
    a, b = value.split(",")
    # Workaround for configargparse issues:
    # If the list values are given from yaml file,
    # the value given to type() is shaped as python-list,
    # e.g. ['a', 'b', 'c'],
    # so we need to remove double quotes from it.
    return remove_quotes(a), remove_quotes(b)


def str2triple_str(value: str) -> Tuple[str, str, str]:
    """str2triple_str.

    Examples:
        >>> str2triple_str('abc,def ,ghi')
        ('abc', 'def', 'ghi')
    """
    value = remove_parenthesis(value)
    a, b, c = value.split(",")
    # Workaround for configargparse issues:
    # If the list values are given from yaml file,
    # the value given to type() is shaped as python-list,
    # e.g. ['a', 'b', 'c'],
    # so we need to remove quotes from it.
    return remove_quotes(a), remove_quotes(b), remove_quotes(c)
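# A minimal sketch of how these helpers plug into argparse (the flag names below are
# illustrative, not taken from this repository):
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--ngpu', type=int_or_none, default=None)
#     parser.add_argument('--data_path_and_name_and_type', type=str2triple_str, action='append')
#     args = parser.parse_args(['--ngpu', 'none',
#                               '--data_path_and_name_and_type', 'dump/wav.scp,speech,sound'])
#     # args.ngpu is None; args.data_path_and_name_and_type == [('dump/wav.scp', 'speech', 'sound')]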

View File

@@ -0,0 +1,321 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import shutil
from multiprocessing import Pool
from typing import Any, Dict, Union

# import kaldiio
import librosa
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi


def ndarray_resample(audio_in: np.ndarray,
                     fs_in: int = 16000,
                     fs_out: int = 16000) -> np.ndarray:
    audio_out = audio_in
    if fs_in != fs_out:
        audio_out = librosa.resample(audio_in, orig_sr=fs_in, target_sr=fs_out)
    return audio_out


def torch_resample(audio_in: torch.Tensor,
                   fs_in: int = 16000,
                   fs_out: int = 16000) -> torch.Tensor:
    audio_out = audio_in
    if fs_in != fs_out:
        audio_out = torchaudio.transforms.Resample(orig_freq=fs_in,
                                                   new_freq=fs_out)(audio_in)
    return audio_out
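# e.g. (a small usage sketch; 'example.wav' is a placeholder path):
#     wav, sr = torchaudio.load('example.wav')
#     wav_16k = torch_resample(wav, sr, 16000)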
def extract_CMVN_featrures(mvn_file):
    """
    extract CMVN from cmvn.ark
    """
    if not os.path.exists(mvn_file):
        return None
    try:
        # kaldiio is only needed for ark-format cmvn files, so import it lazily here;
        # if it is missing (or the file is not an ark), fall back to the text parser below
        import kaldiio
        cmvn = kaldiio.load_mat(mvn_file)
        means = []
        variance = []
        for i in range(cmvn.shape[1] - 1):
            means.append(float(cmvn[0][i]))
        count = float(cmvn[0][-1])
        for i in range(cmvn.shape[1] - 1):
            variance.append(float(cmvn[1][i]))
        for i in range(len(means)):
            means[i] /= count
            variance[i] = variance[i] / count - means[i] * means[i]
            if variance[i] < 1.0e-20:
                variance[i] = 1.0e-20
            variance[i] = 1.0 / math.sqrt(variance[i])
        cmvn = np.array([means, variance])
        return cmvn
    except Exception:
        cmvn = extract_CMVN_features_txt(mvn_file)
        return cmvn
def extract_CMVN_features_txt(mvn_file):  # noqa
    with open(mvn_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    add_shift_list = []
    rescale_list = []
    for i in range(len(lines)):
        line_item = lines[i].split()
        if line_item[0] == '<AddShift>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                add_shift_line = line_item[3:(len(line_item) - 1)]
                add_shift_list = list(add_shift_line)
                continue
        elif line_item[0] == '<Rescale>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                rescale_line = line_item[3:(len(line_item) - 1)]
                rescale_list = list(rescale_line)
                continue
    add_shift_list_f = [float(s) for s in add_shift_list]
    rescale_list_f = [float(s) for s in rescale_list]
    cmvn = np.array([add_shift_list_f, rescale_list_f])
    return cmvn
def build_LFR_features(inputs, m=7, n=6):  # noqa
    """
    Actually, this implements stacking frames and skipping frames.
    if m = 1 and n = 1, just return the origin features.
    if m = 1 and n > 1, it works like skipping.
    if m > 1 and n = 1, it works like stacking but only support right frames.
    if m > 1 and n > 1, it works like LFR.

    Args:
        inputs: T x D np.ndarray of input features
        m: number of frames to stack
        n: number of frames to skip
    """
    # LFR_inputs_batch = []
    # for inputs in inputs_batch:
    LFR_inputs = []
    T = inputs.shape[0]
    T_lfr = int(np.ceil(T / n))
    left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
    inputs = np.vstack((left_padding, inputs))
    T = T + (m - 1) // 2
    for i in range(T_lfr):
        if m <= T - i * n:
            LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
        else:  # process last LFR frame
            num_padding = m - (T - i * n)
            frame = np.hstack(inputs[i * n:])
            for _ in range(num_padding):
                frame = np.hstack((frame, inputs[-1]))
            LFR_inputs.append(frame)
    return np.vstack(LFR_inputs)
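# Worked example (a sketch): with the defaults m=7, n=6 and an input of shape (100, 80),
# (m - 1) // 2 = 3 copies of the first frame are prepended, T_lfr = ceil(100 / 6) = 17,
# and the returned array has shape (17, 7 * 80) = (17, 560): 7 stacked frames per output
# frame, advancing 6 input frames at a time.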
def compute_fbank(wav_file,
                  num_mel_bins=80,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0,
                  is_pcm=False,
                  fs: Union[int, Dict[Any, int]] = 16000):
    audio_sr: int = 16000
    model_sr: int = 16000
    if isinstance(fs, int):
        model_sr = fs
        audio_sr = fs
    else:
        model_sr = fs['model_fs']
        audio_sr = fs['audio_fs']
    if is_pcm is True:
        # byte(PCM16) to float32, and resample
        value = wav_file
        middle_data = np.frombuffer(value, dtype=np.int16)
        middle_data = np.asarray(middle_data)
        if middle_data.dtype.kind not in 'iu':
            raise TypeError("'middle_data' must be an array of integers")
        dtype = np.dtype('float32')
        if dtype.kind != 'f':
            raise TypeError("'dtype' must be a floating point type")
        i = np.iinfo(middle_data.dtype)
        abs_max = 2 ** (i.bits - 1)
        offset = i.min + abs_max
        waveform = np.frombuffer(
            (middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
        waveform = ndarray_resample(waveform, audio_sr, model_sr)
        waveform = torch.from_numpy(waveform.reshape(1, -1))
    else:
        # load pcm from wav, and resample
        waveform, audio_sr = torchaudio.load(wav_file)
        waveform = waveform * (1 << 15)
        waveform = torch_resample(waveform, audio_sr, model_sr)
    mat = kaldi.fbank(waveform,
                      num_mel_bins=num_mel_bins,
                      frame_length=frame_length,
                      frame_shift=frame_shift,
                      dither=dither,
                      energy_floor=0.0,
                      window_type='hamming',
                      sample_frequency=model_sr)
    input_feats = mat
    return input_feats
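# Minimal usage sketch ('example.wav' is a placeholder, not a file shipped with this repo):
#
#     feats = compute_fbank('example.wav', num_mel_bins=80, fs=16000)   # (num_frames, 80) torch.Tensor
#     lfr_feats = build_LFR_features(feats.numpy(), m=7, n=6)           # low-frame-rate stacking, (ceil(T/6), 560)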
def wav2num_frame(wav_path, frontend_conf):
    waveform, sampling_rate = torchaudio.load(wav_path)
    speech_length = (waveform.shape[1] / sampling_rate) * 1000.
    n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
    feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
    return n_frames, feature_dim, speech_length


def calc_shape_core(root_path, frontend_conf, speech_length_min, speech_length_max, idx):
    wav_scp_file = os.path.join(root_path, "wav.scp.{}".format(idx))
    shape_file = os.path.join(root_path, "speech_shape.{}".format(idx))
    with open(wav_scp_file) as f:
        lines = f.readlines()
    with open(shape_file, "w") as f:
        for line in lines:
            sample_name, wav_path = line.strip().split()
            n_frames, feature_dim, speech_length = wav2num_frame(wav_path, frontend_conf)
            write_flag = True
            if speech_length_min > 0 and speech_length < speech_length_min:
                write_flag = False
            if speech_length_max > 0 and speech_length > speech_length_max:
                write_flag = False
            if write_flag:
                f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim))))
                f.flush()
def calc_shape(data_dir, dataset, frontend_conf, speech_length_min=-1, speech_length_max=-1, nj=32):
    shape_path = os.path.join(data_dir, dataset, "shape_files")
    if os.path.exists(shape_path):
        assert os.path.exists(os.path.join(data_dir, dataset, "speech_shape"))
        print('Shape file for small dataset already exists.')
        return
    os.makedirs(shape_path, exist_ok=True)

    # split
    wav_scp_file = os.path.join(data_dir, dataset, "wav.scp")
    with open(wav_scp_file) as f:
        lines = f.readlines()
    num_lines = len(lines)
    num_job_lines = num_lines // nj
    start = 0
    for i in range(nj):
        end = start + num_job_lines
        file = os.path.join(shape_path, "wav.scp.{}".format(str(i + 1)))
        with open(file, "w") as f:
            if i == nj - 1:
                f.writelines(lines[start:])
            else:
                f.writelines(lines[start:end])
        start = end

    p = Pool(nj)
    for i in range(nj):
        p.apply_async(calc_shape_core,
                      args=(shape_path, frontend_conf, speech_length_min, speech_length_max, str(i + 1)))
    print('Generating shape files, please wait a few minutes...')
    p.close()
    p.join()

    # combine
    file = os.path.join(data_dir, dataset, "speech_shape")
    with open(file, "w") as f:
        for i in range(nj):
            job_file = os.path.join(shape_path, "speech_shape.{}".format(str(i + 1)))
            with open(job_file) as job_f:
                lines = job_f.readlines()
            f.writelines(lines)
    print('Generating shape files done.')
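# The resulting "speech_shape" file has one line per utterance in the form
# "<sample_name> <num_lfr_frames>,<feature_dim>", e.g. (illustrative values):
#     BAC009S0002W0122 562,560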
def generate_data_list(data_dir, dataset, nj=100):
    split_dir = os.path.join(data_dir, dataset, "split")
    if os.path.exists(split_dir):
        assert os.path.exists(os.path.join(data_dir, dataset, "data.list"))
        print('Data list for large dataset already exists.')
        return
    os.makedirs(split_dir, exist_ok=True)
    with open(os.path.join(data_dir, dataset, "wav.scp")) as f_wav:
        wav_lines = f_wav.readlines()
    with open(os.path.join(data_dir, dataset, "text")) as f_text:
        text_lines = f_text.readlines()
    total_num_lines = len(wav_lines)
    num_lines = total_num_lines // nj
    start_num = 0
    for i in range(nj):
        end_num = start_num + num_lines
        split_dir_nj = os.path.join(split_dir, str(i + 1))
        os.mkdir(split_dir_nj)
        wav_file = os.path.join(split_dir_nj, 'wav.scp')
        text_file = os.path.join(split_dir_nj, "text")
        with open(wav_file, "w") as fw, open(text_file, "w") as ft:
            if i == nj - 1:
                fw.writelines(wav_lines[start_num:])
                ft.writelines(text_lines[start_num:])
            else:
                fw.writelines(wav_lines[start_num:end_num])
                ft.writelines(text_lines[start_num:end_num])
        start_num = end_num

    data_list_file = os.path.join(data_dir, dataset, "data.list")
    with open(data_list_file, "w") as f_data:
        for i in range(nj):
            wav_path = os.path.join(split_dir, str(i + 1), "wav.scp")
            text_path = os.path.join(split_dir, str(i + 1), "text")
            f_data.write(wav_path + " " + text_path + "\n")
def filter_wav_text(data_dir, dataset):
    wav_file = os.path.join(data_dir, dataset, "wav.scp")
    text_file = os.path.join(data_dir, dataset, "text")
    with open(wav_file) as f_wav, open(text_file) as f_text:
        wav_lines = f_wav.readlines()
        text_lines = f_text.readlines()
    os.rename(wav_file, "{}.bak".format(wav_file))
    os.rename(text_file, "{}.bak".format(text_file))
    wav_dict = {}
    for line in wav_lines:
        parts = line.strip().split()
        if len(parts) != 2:
            continue
        sample_name, wav_path = parts
        wav_dict[sample_name] = wav_path
    text_dict = {}
    for line in text_lines:
        parts = line.strip().split()
        if len(parts) < 2:
            continue
        sample_name = parts[0]
        text_dict[sample_name] = " ".join(parts[1:]).lower()
    filter_count = 0
    with open(wav_file, "w") as f_wav, open(text_file, "w") as f_text:
        for sample_name, wav_path in wav_dict.items():
            if sample_name in text_dict.keys():
                f_wav.write(sample_name + " " + wav_path + "\n")
                f_text.write(sample_name + " " + text_dict[sample_name] + "\n")
            else:
                filter_count += 1
    print("{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".format(
        filter_count, len(wav_lines), dataset))

View File

@@ -0,0 +1,14 @@
import yaml


class NoAliasSafeDumper(yaml.SafeDumper):
    # Disable anchor/alias in yaml because it looks ugly
    def ignore_aliases(self, data):
        return True


def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
    """Safe-dump in yaml with no anchor/alias"""
    return yaml.dump(
        data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
    )
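# A small sketch of why the no-alias dumper matters (the config values are illustrative):
#
#     shared = {"lr": 0.001}
#     config = {"optimizer": shared, "scheduler": shared}
#     yaml.safe_dump(config)             # emits "&id001" / "*id001" anchor-alias pairs
#     yaml_no_alias_safe_dump(config)    # repeats the shared mapping in both places instead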