Mirror of https://github.com/HumanAIGC/lite-avatar.git
Synced 2026-02-05 18:09:20 +08:00

Commit: add files
funasr_local/utils/__init__.py  (new file, empty)
funasr_local/utils/asr_env_checking.py  (new file)
@@ -0,0 +1,85 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import shutil
import ssl

import nltk

# mkdir nltk_data dir if not exist
try:
    nltk.data.find('.')
except LookupError:
    dir_list = nltk.data.path
    for dir_item in dir_list:
        if not os.path.exists(dir_item):
            os.mkdir(dir_item)
        if os.path.exists(dir_item):
            break

# download one package if nltk_data not exist
try:
    nltk.data.find('.')
except:  # noqa: *
    try:
        _create_unverified_https_context = ssl._create_unverified_context
    except AttributeError:
        pass
    else:
        ssl._create_default_https_context = _create_unverified_https_context

    nltk.download('cmudict', halt_on_error=False, raise_on_error=True)

# deploy taggers/averaged_perceptron_tagger
try:
    nltk.data.find('taggers/averaged_perceptron_tagger')
except:  # noqa: *
    data_dir = nltk.data.find('.')
    target_dir = os.path.join(data_dir, 'taggers')
    if not os.path.exists(target_dir):
        os.mkdir(target_dir)
    src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages',
                            'averaged_perceptron_tagger.zip')
    shutil.copyfile(src_file,
                    os.path.join(target_dir, 'averaged_perceptron_tagger.zip'))
    shutil._unpack_zipfile(
        os.path.join(target_dir, 'averaged_perceptron_tagger.zip'), target_dir)

# deploy corpora/cmudict
try:
    nltk.data.find('corpora/cmudict')
except:  # noqa: *
    data_dir = nltk.data.find('.')
    target_dir = os.path.join(data_dir, 'corpora')
    if not os.path.exists(target_dir):
        os.mkdir(target_dir)
    src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages',
                            'cmudict.zip')
    shutil.copyfile(src_file, os.path.join(target_dir, 'cmudict.zip'))
    shutil._unpack_zipfile(os.path.join(target_dir, 'cmudict.zip'), target_dir)

try:
    nltk.data.find('taggers/averaged_perceptron_tagger')
except:  # noqa: *
    try:
        _create_unverified_https_context = ssl._create_unverified_context
    except AttributeError:
        pass
    else:
        ssl._create_default_https_context = _create_unverified_https_context

    nltk.download('averaged_perceptron_tagger',
                  halt_on_error=False,
                  raise_on_error=True)

try:
    nltk.data.find('corpora/cmudict')
except:  # noqa: *
    try:
        _create_unverified_https_context = ssl._create_unverified_context
    except AttributeError:
        pass
    else:
        ssl._create_default_https_context = _create_unverified_https_context

    nltk.download('cmudict', halt_on_error=False, raise_on_error=True)
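Usage note (illustrative, not part of the commit): the checks above run at import time, so a single import of the module is enough to prepare the nltk_data directories.

# Importing the module triggers the nltk_data checks/downloads as a side effect.
import funasr_local.utils.asr_env_checking  # noqa: F401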
funasr_local/utils/asr_utils.py  (new file)
@@ -0,0 +1,355 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import struct
from typing import Any, Dict, List, Union

import torchaudio
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger

logger = get_logger()

green_color = '\033[1;32m'
red_color = '\033[0;31;40m'
yellow_color = '\033[0;33;40m'
end_color = '\033[0m'

global_asr_language = 'zh-cn'

SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']


def get_version():
    return float(pkg_resources.get_distribution('easyasr').version)


def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str):
    r_audio_fs = None

    if audio_format == 'wav' or audio_format == 'scp':
        r_audio_fs = get_sr_from_wav(audio_in)
    elif audio_format == 'pcm' and isinstance(audio_in, bytes):
        r_audio_fs = get_sr_from_bytes(audio_in)

    return r_audio_fs


def type_checking(audio_in: Union[str, bytes],
                  audio_fs: int = None,
                  recog_type: str = None,
                  audio_format: str = None):
    r_recog_type = recog_type
    r_audio_format = audio_format
    r_wav_path = audio_in

    if isinstance(audio_in, str):
        assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist'
    elif isinstance(audio_in, bytes):
        assert len(audio_in) > 0, 'audio in is empty'
        r_audio_format = 'pcm'
        r_recog_type = 'wav'

    if audio_in is None:
        # for raw_inputs
        r_recog_type = 'wav'
        r_audio_format = 'pcm'

    if r_recog_type is None and audio_in is not None:
        # audio_in is wav, recog_type is wav_file
        if os.path.isfile(audio_in):
            audio_type = os.path.basename(audio_in).lower()
            for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
                if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
                    r_recog_type = 'wav'
                    r_audio_format = 'wav'
            if audio_type.rfind(".scp") >= 0:
                r_recog_type = 'wav'
                r_audio_format = 'scp'
            if r_recog_type is None:
                raise NotImplementedError(
                    f'Not supported audio type: {audio_type}')

        # recog_type is datasets_file
        elif os.path.isdir(audio_in):
            dir_name = os.path.basename(audio_in)
            if 'test' in dir_name:
                r_recog_type = 'test'
            elif 'dev' in dir_name:
                r_recog_type = 'dev'
            elif 'train' in dir_name:
                r_recog_type = 'train'

            if r_audio_format is None:
                if find_file_by_ends(audio_in, '.ark'):
                    r_audio_format = 'kaldi_ark'
                elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends(
                        audio_in, '.WAV'):
                    r_audio_format = 'wav'
                elif find_file_by_ends(audio_in, '.records'):
                    r_audio_format = 'tfrecord'

    if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
        # datasets with kaldi_ark file
        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
    elif r_audio_format == 'tfrecord' and r_recog_type != 'wav':
        # datasets with tensorflow records file
        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
    elif r_audio_format == 'wav' and r_recog_type != 'wav':
        # datasets with waveform files
        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))

    return r_recog_type, r_audio_format, r_wav_path


def get_sr_from_bytes(wav: bytes):
    sr = None
    data = wav
    if len(data) > 44:
        try:
            header_fields = {}
            header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
            header_fields['Format'] = str(data[8:12], 'UTF-8')
            header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
            if header_fields['ChunkID'] == 'RIFF' and header_fields[
                    'Format'] == 'WAVE' and header_fields[
                        'Subchunk1ID'] == 'fmt ':
                header_fields['SampleRate'] = struct.unpack('<I',
                                                            data[24:28])[0]
                sr = header_fields['SampleRate']
        except Exception:
            # no treatment
            pass
    else:
        logger.warn('audio bytes is ' + str(len(data)) + ' is invalid.')

    return sr


def get_sr_from_wav(fname: str):
    fs = None
    if os.path.isfile(fname):
        audio_type = os.path.basename(fname).lower()
        for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
            if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
                if support_audio_type == "pcm":
                    fs = None
                else:
                    audio, fs = torchaudio.load(fname)
                break
        if audio_type.rfind(".scp") >= 0:
            with open(fname, encoding="utf-8") as f:
                for line in f:
                    wav_path = line.split()[1]
                    fs = get_sr_from_wav(wav_path)
                    if fs is not None:
                        break
            return fs
    elif os.path.isdir(fname):
        dir_files = os.listdir(fname)
        for file in dir_files:
            file_path = os.path.join(fname, file)
            if os.path.isfile(file_path):
                fs = get_sr_from_wav(file_path)
            elif os.path.isdir(file_path):
                fs = get_sr_from_wav(file_path)

            if fs is not None:
                break

    return fs


def find_file_by_ends(dir_path: str, ends: str):
    dir_files = os.listdir(dir_path)
    for file in dir_files:
        file_path = os.path.join(dir_path, file)
        if os.path.isfile(file_path):
            if ends == ".wav" or ends == ".WAV":
                audio_type = os.path.basename(file_path).lower()
                for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
                    if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
                        return True
                raise NotImplementedError(
                    f'Not supported audio type: {audio_type}')
            elif file_path.endswith(ends):
                return True
        elif os.path.isdir(file_path):
            if find_file_by_ends(file_path, ends):
                return True

    return False


def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]:
    dir_files = os.listdir(dir_path)
    for file in dir_files:
        file_path = os.path.join(dir_path, file)
        if os.path.isfile(file_path):
            audio_type = os.path.basename(file_path).lower()
            for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
                if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
                    wav_list.append(file_path)
        elif os.path.isdir(file_path):
            recursion_dir_all_wav(wav_list, file_path)

    return wav_list


def compute_wer(hyp_list: List[Any],
                ref_list: List[Any],
                lang: str = None) -> Dict[str, Any]:
    assert len(hyp_list) > 0, 'hyp list is empty'
    assert len(ref_list) > 0, 'ref list is empty'

    rst = {
        'Wrd': 0,
        'Corr': 0,
        'Ins': 0,
        'Del': 0,
        'Sub': 0,
        'Snt': 0,
        'Err': 0.0,
        'S.Err': 0.0,
        'wrong_words': 0,
        'wrong_sentences': 0
    }

    if lang is None:
        lang = global_asr_language

    for h_item in hyp_list:
        for r_item in ref_list:
            if h_item['key'] == r_item['key']:
                out_item = compute_wer_by_line(h_item['value'],
                                               r_item['value'],
                                               lang)
                rst['Wrd'] += out_item['nwords']
                rst['Corr'] += out_item['cor']
                rst['wrong_words'] += out_item['wrong']
                rst['Ins'] += out_item['ins']
                rst['Del'] += out_item['del']
                rst['Sub'] += out_item['sub']
                rst['Snt'] += 1
                if out_item['wrong'] > 0:
                    rst['wrong_sentences'] += 1
                    print_wrong_sentence(key=h_item['key'],
                                         hyp=h_item['value'],
                                         ref=r_item['value'])
                else:
                    print_correct_sentence(key=h_item['key'],
                                           hyp=h_item['value'],
                                           ref=r_item['value'])

                break

    if rst['Wrd'] > 0:
        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
    if rst['Snt'] > 0:
        rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)

    return rst


def compute_wer_by_line(hyp: List[str],
                        ref: List[str],
                        lang: str = 'zh-cn') -> Dict[str, Any]:
    if lang != 'zh-cn':
        hyp = hyp.split()
        ref = ref.split()

    hyp = list(map(lambda x: x.lower(), hyp))
    ref = list(map(lambda x: x.lower(), ref))

    len_hyp = len(hyp)
    len_ref = len(ref)

    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)

    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)

    for i in range(len_hyp + 1):
        cost_matrix[i][0] = i
    for j in range(len_ref + 1):
        cost_matrix[0][j] = j

    for i in range(1, len_hyp + 1):
        for j in range(1, len_ref + 1):
            if hyp[i - 1] == ref[j - 1]:
                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
            else:
                substitution = cost_matrix[i - 1][j - 1] + 1
                insertion = cost_matrix[i - 1][j] + 1
                deletion = cost_matrix[i][j - 1] + 1

                compare_val = [substitution, insertion, deletion]

                min_val = min(compare_val)
                operation_idx = compare_val.index(min_val) + 1
                cost_matrix[i][j] = min_val
                ops_matrix[i][j] = operation_idx

    match_idx = []
    i = len_hyp
    j = len_ref
    rst = {
        'nwords': len_ref,
        'cor': 0,
        'wrong': 0,
        'ins': 0,
        'del': 0,
        'sub': 0
    }
    while i >= 0 or j >= 0:
        i_idx = max(0, i)
        j_idx = max(0, j)

        if ops_matrix[i_idx][j_idx] == 0:  # correct
            if i - 1 >= 0 and j - 1 >= 0:
                match_idx.append((j - 1, i - 1))
                rst['cor'] += 1

            i -= 1
            j -= 1

        elif ops_matrix[i_idx][j_idx] == 2:  # insert
            i -= 1
            rst['ins'] += 1

        elif ops_matrix[i_idx][j_idx] == 3:  # delete
            j -= 1
            rst['del'] += 1

        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
            i -= 1
            j -= 1
            rst['sub'] += 1

        if i < 0 and j >= 0:
            rst['del'] += 1
        elif j < 0 and i >= 0:
            rst['ins'] += 1

    match_idx.reverse()
    wrong_cnt = cost_matrix[len_hyp][len_ref]
    rst['wrong'] = wrong_cnt

    return rst


def print_wrong_sentence(key: str, hyp: str, ref: str):
    space = len(key)
    print(key + yellow_color + ' ref: ' + ref)
    print(' ' * space + red_color + ' hyp: ' + hyp + end_color)


def print_correct_sentence(key: str, hyp: str, ref: str):
    space = len(key)
    print(key + yellow_color + ' ref: ' + ref)
    print(' ' * space + green_color + ' hyp: ' + hyp + end_color)


def print_progress(percent):
    if percent > 1:
        percent = 1
    res = int(50 * percent) * '#'
    print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='')
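A minimal sketch (illustrative, not part of the commit) of how compute_wer above consumes its inputs: each item is a dict with 'key' and 'value' fields, where 'value' is the transcript string (iterated character by character for zh-cn). The utterance id and text are made up.

from funasr_local.utils.asr_utils import compute_wer

# hypothetical recognition result and reference, keyed by utterance id
hyp_list = [{'key': 'utt1', 'value': '今天天气'}]
ref_list = [{'key': 'utt1', 'value': '今天天气'}]

rst = compute_wer(hyp_list, ref_list, lang='zh-cn')
print(rst['Err'], rst['S.Err'])  # 0.0 0.0 for a perfect match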
funasr_local/utils/build_dataclass.py  (new file)
@@ -0,0 +1,17 @@
import argparse
import dataclasses

from typeguard import check_type


def build_dataclass(dataclass, args: argparse.Namespace):
    """Helper function to build dataclass from 'args'."""
    kwargs = {}
    for field in dataclasses.fields(dataclass):
        if not hasattr(args, field.name):
            raise ValueError(
                f"args doesn't have {field.name}. You need to set it to ArgumentsParser"
            )
        check_type(field.name, getattr(args, field.name), field.type)
        kwargs[field.name] = getattr(args, field.name)
    return dataclass(**kwargs)
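A usage sketch (illustrative, not part of the commit), assuming the typeguard 2.x check_type API that this helper calls; TrainerOptions and its fields are hypothetical.

import argparse
import dataclasses

from funasr_local.utils.build_dataclass import build_dataclass


@dataclasses.dataclass
class TrainerOptions:           # hypothetical options container
    max_epoch: int
    output_dir: str


# Namespace fields must cover every dataclass field, or build_dataclass raises.
args = argparse.Namespace(max_epoch=30, output_dir="exp/demo")
options = build_dataclass(TrainerOptions, args)
print(options)  # TrainerOptions(max_epoch=30, output_dir='exp/demo')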
funasr_local/utils/cli_utils.py  (new file)
@@ -0,0 +1,65 @@
from collections.abc import Sequence
from distutils.util import strtobool as dist_strtobool
import sys

import numpy


def strtobool(x):
    # distutils.util.strtobool returns integer, but it's confusing,
    return bool(dist_strtobool(x))


def get_commandline_args():
    extra_chars = [
        " ",
        ";",
        "&",
        "(",
        ")",
        "|",
        "^",
        "<",
        ">",
        "?",
        "*",
        "[",
        "]",
        "$",
        "`",
        '"',
        "\\",
        "!",
        "{",
        "}",
    ]

    # Escape the extra characters for shell
    argv = [
        arg.replace("'", "'\\''")
        if all(char not in arg for char in extra_chars)
        else "'" + arg.replace("'", "'\\''") + "'"
        for arg in sys.argv
    ]

    return sys.executable + " " + " ".join(argv)


def is_scipy_wav_style(value):
    # If Tuple[int, numpy.ndarray] or not
    return (
        isinstance(value, Sequence)
        and len(value) == 2
        and isinstance(value[0], int)
        and isinstance(value[1], numpy.ndarray)
    )


def assert_scipy_wav_style(value):
    assert is_scipy_wav_style(
        value
    ), "Must be Tuple[int, numpy.ndarray], but got {}".format(
        type(value)
        if not isinstance(value, Sequence)
        else "{}[{}]".format(type(value), ", ".join(str(type(v)) for v in value))
    )
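A short usage sketch (illustrative, not part of the commit):

from funasr_local.utils.cli_utils import get_commandline_args, strtobool

# Reconstruct a shell-safe version of the current command line, e.g. for logging.
print(get_commandline_args())

# strtobool mirrors the distutils semantics but returns a real bool.
assert strtobool("yes") is True and strtobool("0") is False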
funasr_local/utils/compute_eer.py  (new file)
@@ -0,0 +1,59 @@
import numpy as np
from sklearn.metrics import roc_curve
import argparse


def _compute_eer(label, pred, positive_label=1):
    """
    Python compute equal error rate (eer)
    ONLY tested on binary classification

    :param label: ground-truth label, should be a 1-d list or np.array, each element represents the ground-truth label of one sample
    :param pred: model prediction, should be a 1-d list or np.array, each element represents the model prediction of one sample
    :param positive_label: the class that is viewed as positive class when computing EER
    :return: equal error rate (EER)
    """

    # fpr, tpr, fnr, threshold are all np.array lists over the candidate thresholds
    fpr, tpr, threshold = roc_curve(label, pred, pos_label=positive_label)
    fnr = 1 - tpr

    # the threshold at which fnr == fpr
    eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]

    # theoretically eer from fpr and eer from fnr should be identical, but they can differ slightly in practice
    eer_1 = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    eer_2 = fnr[np.nanargmin(np.absolute((fnr - fpr)))]

    # return the mean of eer from fpr and from fnr
    eer = (eer_1 + eer_2) / 2
    return eer, eer_threshold


def compute_eer(trials_path, scores_path):
    labels = []
    for one_line in open(trials_path, "r"):
        labels.append(one_line.strip().rsplit(" ", 1)[-1] == "target")
    labels = np.array(labels, dtype=int)

    scores = []
    for one_line in open(scores_path, "r"):
        scores.append(float(one_line.strip().rsplit(" ", 1)[-1]))
    scores = np.array(scores, dtype=float)

    eer, threshold = _compute_eer(labels, scores)
    return eer, threshold


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("trials", help="trial list")
    parser.add_argument("scores", help="score file, normalized to [0, 1]")
    args = parser.parse_args()

    eer, threshold = compute_eer(args.trials, args.scores)
    print("EER is {:.4f} at threshold {:.4f}".format(eer * 100.0, threshold))


if __name__ == '__main__':
    main()
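A toy example (illustrative, not part of the commit) of _compute_eer on made-up binary scores, where a higher score should mean "target":

import numpy as np

from funasr_local.utils.compute_eer import _compute_eer

labels = np.array([1, 1, 1, 0, 0, 0])
scores = np.array([0.9, 0.8, 0.3, 0.6, 0.2, 0.1])

eer, threshold = _compute_eer(labels, scores, positive_label=1)
print(f"EER={eer:.3f} at threshold {threshold:.2f}")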
funasr_local/utils/compute_min_dcf.py  (new file)
@@ -0,0 +1,159 @@
#!/usr/bin/env python3
# Copyright 2018  David Snyder
# Apache 2.0

# This script computes the minimum detection cost function, which is a common
# error metric used in speaker recognition. Compared to equal error-rate,
# which assigns equal weight to false negatives and false positives, this
# error-rate is usually used to assess performance in settings where achieving
# a low false positive rate is more important than achieving a low false
# negative rate. See the NIST 2016 Speaker Recognition Evaluation Plan at
# https://www.nist.gov/sites/default/files/documents/2016/10/07/sre16_eval_plan_v1.3.pdf
# for more details about the metric.
from __future__ import print_function
from operator import itemgetter
import sys, argparse, os


def GetArgs():
    parser = argparse.ArgumentParser(description="Compute the minimum "
        "detection cost function along with the threshold at which it occurs. "
        "Usage: sid/compute_min_dcf.py [options...] <scores-file> "
        "<trials-file> "
        "E.g., sid/compute_min_dcf.py --p-target 0.01 --c-miss 1 --c-fa 1 "
        "exp/scores/trials data/test/trials",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--p-target', type=float, dest="p_target",
                        default=0.01,
                        help='The prior probability of the target speaker in a trial.')
    parser.add_argument('--c-miss', type=float, dest="c_miss", default=1,
                        help='Cost of a missed detection.  This is usually not changed.')
    parser.add_argument('--c-fa', type=float, dest="c_fa", default=1,
                        help='Cost of a spurious detection.  This is usually not changed.')
    parser.add_argument("scores_filename",
                        help="Input scores file, with columns of the form "
                        "<utt1> <utt2> <score>")
    parser.add_argument("trials_filename",
                        help="Input trials file, with columns of the form "
                        "<utt1> <utt2> <target/nontarget>")
    sys.stderr.write(' '.join(sys.argv) + "\n")
    args = parser.parse_args()
    args = CheckArgs(args)
    return args


def CheckArgs(args):
    if args.c_fa <= 0:
        raise Exception("--c-fa must be greater than 0")
    if args.c_miss <= 0:
        raise Exception("--c-miss must be greater than 0")
    if args.p_target <= 0 or args.p_target >= 1:
        raise Exception("--p-target must be greater than 0 and less than 1")
    return args


# Creates a list of false-negative rates, a list of false-positive rates
# and a list of decision thresholds that give those error-rates.
def ComputeErrorRates(scores, labels):

    # Sort the scores from smallest to largest, and also get the corresponding
    # indexes of the sorted scores.  We will treat the sorted scores as the
    # thresholds at which the error-rates are evaluated.
    sorted_indexes, thresholds = zip(*sorted(
        [(index, threshold) for index, threshold in enumerate(scores)],
        key=itemgetter(1)))
    labels = [labels[i] for i in sorted_indexes]
    fns = []
    tns = []

    # At the end of this loop, fns[i] is the number of errors made by
    # incorrectly rejecting scores less than thresholds[i]. And, tns[i]
    # is the total number of times that we have correctly rejected scores
    # less than thresholds[i].
    for i in range(0, len(labels)):
        if i == 0:
            fns.append(labels[i])
            tns.append(1 - labels[i])
        else:
            fns.append(fns[i-1] + labels[i])
            tns.append(tns[i-1] + 1 - labels[i])
    positives = sum(labels)
    negatives = len(labels) - positives

    # Now divide the false negatives by the total number of
    # positives to obtain the false negative rates across
    # all thresholds
    fnrs = [fn / float(positives) for fn in fns]

    # Divide the true negatives by the total number of
    # negatives to get the true negative rate.  Subtract these
    # quantities from 1 to get the false positive rates.
    fprs = [1 - tn / float(negatives) for tn in tns]
    return fnrs, fprs, thresholds


# Computes the minimum of the detection cost function.  The comments refer to
# equations in Section 3 of the NIST 2016 Speaker Recognition Evaluation Plan.
def ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa):
    min_c_det = float("inf")
    min_c_det_threshold = thresholds[0]
    for i in range(0, len(fnrs)):
        # See Equation (2).  It is a weighted sum of false negative
        # and false positive errors.
        c_det = c_miss * fnrs[i] * p_target + c_fa * fprs[i] * (1 - p_target)
        if c_det < min_c_det:
            min_c_det = c_det
            min_c_det_threshold = thresholds[i]
    # See Equations (3) and (4).  Now we normalize the cost.
    c_def = min(c_miss * p_target, c_fa * (1 - p_target))
    min_dcf = min_c_det / c_def
    return min_dcf, min_c_det_threshold


def compute_min_dcf(scores_filename, trials_filename, c_miss=1, c_fa=1, p_target=0.01):
    scores_file = open(scores_filename, 'r').readlines()
    trials_file = open(trials_filename, 'r').readlines()
    c_miss = c_miss
    c_fa = c_fa
    p_target = p_target

    scores = []
    labels = []

    trials = {}
    for line in trials_file:
        utt1, utt2, target = line.rstrip().split()
        trial = utt1 + " " + utt2
        trials[trial] = target

    for line in scores_file:
        utt1, utt2, score = line.rstrip().split()
        trial = utt1 + " " + utt2
        if trial in trials:
            scores.append(float(score))
            if trials[trial] == "target":
                labels.append(1)
            else:
                labels.append(0)
        else:
            raise Exception("Missing entry for " + utt1 + " and " + utt2
                            + " " + scores_filename)

    fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
    mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target,
                                      c_miss, c_fa)
    return mindcf, threshold


def main():
    args = GetArgs()
    mindcf, threshold = compute_min_dcf(
        args.scores_filename, args.trials_filename,
        args.c_miss, args.c_fa, args.p_target
    )
    sys.stdout.write("minDCF is {0:.4f} at threshold {1:.4f} (p-target={2}, c-miss={3}, "
                     "c-fa={4})\n".format(mindcf, threshold, args.p_target, args.c_miss, args.c_fa))


if __name__ == "__main__":
    main()
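An illustrative invocation (not part of the commit); the file names and contents below are made up but follow the column formats described in the argparse help above.

# scores.txt (hypothetical)      trials.txt (hypothetical)
#   spk1 utt_a 0.82                spk1 utt_a target
#   spk1 utt_b 0.11                spk1 utt_b nontarget

from funasr_local.utils.compute_min_dcf import compute_min_dcf

mindcf, threshold = compute_min_dcf("scores.txt", "trials.txt",
                                    c_miss=1, c_fa=1, p_target=0.01)
print(mindcf, threshold)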
funasr_local/utils/compute_wer.py  (new file)
@@ -0,0 +1,157 @@
import os
import numpy as np
import sys


def compute_wer(ref_file,
                hyp_file,
                cer_detail_file):
    rst = {
        'Wrd': 0,
        'Corr': 0,
        'Ins': 0,
        'Del': 0,
        'Sub': 0,
        'Snt': 0,
        'Err': 0.0,
        'S.Err': 0.0,
        'wrong_words': 0,
        'wrong_sentences': 0
    }

    hyp_dict = {}
    ref_dict = {}
    with open(hyp_file, 'r') as hyp_reader:
        for line in hyp_reader:
            key = line.strip().split()[0]
            value = line.strip().split()[1:]
            hyp_dict[key] = value
    with open(ref_file, 'r') as ref_reader:
        for line in ref_reader:
            key = line.strip().split()[0]
            value = line.strip().split()[1:]
            ref_dict[key] = value

    cer_detail_writer = open(cer_detail_file, 'w')
    for hyp_key in hyp_dict:
        if hyp_key in ref_dict:
            out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
            rst['Wrd'] += out_item['nwords']
            rst['Corr'] += out_item['cor']
            rst['wrong_words'] += out_item['wrong']
            rst['Ins'] += out_item['ins']
            rst['Del'] += out_item['del']
            rst['Sub'] += out_item['sub']
            rst['Snt'] += 1
            if out_item['wrong'] > 0:
                rst['wrong_sentences'] += 1
            cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
            cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
            cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')

    if rst['Wrd'] > 0:
        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
    if rst['Snt'] > 0:
        rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)

    cer_detail_writer.write('\n')
    cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words']) + " / " + str(rst['Wrd']) +
                            ", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
    cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
    cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n')


def compute_wer_by_line(hyp,
                        ref):
    hyp = list(map(lambda x: x.lower(), hyp))
    ref = list(map(lambda x: x.lower(), ref))

    len_hyp = len(hyp)
    len_ref = len(ref)

    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)

    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)

    for i in range(len_hyp + 1):
        cost_matrix[i][0] = i
    for j in range(len_ref + 1):
        cost_matrix[0][j] = j

    for i in range(1, len_hyp + 1):
        for j in range(1, len_ref + 1):
            if hyp[i - 1] == ref[j - 1]:
                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
            else:
                substitution = cost_matrix[i - 1][j - 1] + 1
                insertion = cost_matrix[i - 1][j] + 1
                deletion = cost_matrix[i][j - 1] + 1

                compare_val = [substitution, insertion, deletion]

                min_val = min(compare_val)
                operation_idx = compare_val.index(min_val) + 1
                cost_matrix[i][j] = min_val
                ops_matrix[i][j] = operation_idx

    match_idx = []
    i = len_hyp
    j = len_ref
    rst = {
        'nwords': len_ref,
        'cor': 0,
        'wrong': 0,
        'ins': 0,
        'del': 0,
        'sub': 0
    }
    while i >= 0 or j >= 0:
        i_idx = max(0, i)
        j_idx = max(0, j)

        if ops_matrix[i_idx][j_idx] == 0:  # correct
            if i - 1 >= 0 and j - 1 >= 0:
                match_idx.append((j - 1, i - 1))
                rst['cor'] += 1

            i -= 1
            j -= 1

        elif ops_matrix[i_idx][j_idx] == 2:  # insert
            i -= 1
            rst['ins'] += 1

        elif ops_matrix[i_idx][j_idx] == 3:  # delete
            j -= 1
            rst['del'] += 1

        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
            i -= 1
            j -= 1
            rst['sub'] += 1

        if i < 0 and j >= 0:
            rst['del'] += 1
        elif j < 0 and i >= 0:
            rst['ins'] += 1

    match_idx.reverse()
    wrong_cnt = cost_matrix[len_hyp][len_ref]
    rst['wrong'] = wrong_cnt

    return rst


def print_cer_detail(rst):
    return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
            + ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
            + str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords'])
            + ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords']))


if __name__ == '__main__':
    if len(sys.argv) != 4:
        print("usage : python compute-wer.py test.ref test.hyp test.wer")
        sys.exit(0)

    ref_file = sys.argv[1]
    hyp_file = sys.argv[2]
    cer_detail_file = sys.argv[3]
    compute_wer(ref_file, hyp_file, cer_detail_file)
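A usage sketch (illustrative, not part of the commit); the file paths are hypothetical Kaldi-style "utt-id token token ..." text files.

from funasr_local.utils.compute_wer import compute_wer

# ref.txt / hyp.txt each contain lines such as:  utt1 今 天 天 气
compute_wer("ref.txt", "hyp.txt", "wer_detail.txt")
# wer_detail.txt now holds per-utterance counts plus the %WER / %SER summary lines.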
funasr_local/utils/config_argparse.py  (new file)
@@ -0,0 +1,47 @@
import argparse
from pathlib import Path

import yaml


class ArgumentParser(argparse.ArgumentParser):
    """Simple implementation of ArgumentParser supporting config file

    This class originated from https://github.com/bw2/ConfigArgParse,
    but it lacks some of the features that ConfigArgParse has.

    - Not supporting multiple config files
    - Automatically adding "--config" as an option.
    - Not supporting any formats other than yaml
    - Not checking argument type

    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_argument("--config", help="Give config file in yaml format")

    def parse_known_args(self, args=None, namespace=None):
        # Once parsing for setting from "--config"
        _args, _ = super().parse_known_args(args, namespace)
        if _args.config is not None:
            if not Path(_args.config).exists():
                self.error(f"No such file: {_args.config}")

            with open(_args.config, "r", encoding="utf-8") as f:
                d = yaml.safe_load(f)
            if not isinstance(d, dict):
                self.error(f"Config file has non dict value: {_args.config}")

            for key in d:
                for action in self._actions:
                    if key == action.dest:
                        break
                else:
                    self.error(f"unrecognized arguments: {key} (from {_args.config})")

            # NOTE(kamo): Ignore "--config" from a config file
            # NOTE(kamo): Unlike "configargparse", this module doesn't check type.
            #   i.e. We can set any type value regardless of argument type.
            self.set_defaults(**d)
        return super().parse_known_args(args, namespace)
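A usage sketch (illustrative, not part of the commit); train.yaml is a hypothetical file containing "lr: 0.01". Config values override argument defaults, while explicit command-line values still win over the config file.

from funasr_local.utils.config_argparse import ArgumentParser

parser = ArgumentParser(description="demo")
parser.add_argument("--lr", type=float, default=1e-3)

args = parser.parse_args(["--config", "train.yaml"])
print(args.lr)  # 0.01 from the config, unless --lr is also given on the command line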
funasr_local/utils/get_default_kwargs.py  (new file)
@@ -0,0 +1,57 @@
import inspect


class Invalid:
    """Marker object for not serializable-object"""


def get_default_kwargs(func):
    """Get the default values of the input function.

    Examples:
        >>> def func(a, b=3):  pass
        >>> get_default_kwargs(func)
        {'b': 3}

    """

    def yaml_serializable(value):
        # isinstance(x, tuple) includes namedtuple, so type is used here
        if type(value) is tuple:
            return yaml_serializable(list(value))
        elif isinstance(value, set):
            return yaml_serializable(list(value))
        elif isinstance(value, dict):
            if not all(isinstance(k, str) for k in value):
                return Invalid
            retval = {}
            for k, v in value.items():
                v2 = yaml_serializable(v)
                # Register only valid object
                if v2 not in (Invalid, inspect.Parameter.empty):
                    retval[k] = v2
            return retval
        elif isinstance(value, list):
            retval = []
            for v in value:
                v2 = yaml_serializable(v)
                # If any elements in the list are invalid,
                # the list also becomes invalid
                if v2 is Invalid:
                    return Invalid
                else:
                    retval.append(v2)
            return retval
        elif value in (inspect.Parameter.empty, None):
            return value
        elif isinstance(value, (float, int, complex, bool, str, bytes)):
            return value
        else:
            return Invalid

    # params: An ordered mapping of inspect.Parameter
    params = inspect.signature(func).parameters
    data = {p.name: p.default for p in params.values()}
    # Remove not yaml-serializable object
    data = yaml_serializable(data)
    return data
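A short example (illustrative, not part of the commit) showing how non-serializable defaults are dropped; make_layer is a hypothetical function.

from funasr_local.utils.get_default_kwargs import get_default_kwargs


# Defaults that cannot be written to YAML (here: an object()) are silently dropped,
# as are parameters without any default.
def make_layer(units=256, activation="relu", init=object()):
    pass


print(get_default_kwargs(make_layer))  # {'units': 256, 'activation': 'relu'}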
funasr_local/utils/griffin_lim.py  (new file)
@@ -0,0 +1,192 @@
#!/usr/bin/env python3

"""Griffin-Lim related modules."""

# Copyright 2019 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import logging

from distutils.version import LooseVersion
from functools import partial
from typeguard import check_argument_types
from typing import Optional

import librosa
import numpy as np
import torch

EPS = 1e-10


def logmel2linear(
    lmspc: np.ndarray,
    fs: int,
    n_fft: int,
    n_mels: int,
    fmin: int = None,
    fmax: int = None,
) -> np.ndarray:
    """Convert log Mel filterbank to linear spectrogram.

    Args:
        lmspc: Log Mel filterbank (T, n_mels).
        fs: Sampling frequency.
        n_fft: The number of FFT points.
        n_mels: The number of mel basis.
        fmin: Minimum frequency to analyze.
        fmax: Maximum frequency to analyze.

    Returns:
        Linear spectrogram (T, n_fft // 2 + 1).

    """
    assert lmspc.shape[1] == n_mels
    fmin = 0 if fmin is None else fmin
    fmax = fs / 2 if fmax is None else fmax
    mspc = np.power(10.0, lmspc)
    mel_basis = librosa.filters.mel(
        sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
    )
    inv_mel_basis = np.linalg.pinv(mel_basis)
    return np.maximum(EPS, np.dot(inv_mel_basis, mspc.T).T)


def griffin_lim(
    spc: np.ndarray,
    n_fft: int,
    n_shift: int,
    win_length: int = None,
    window: Optional[str] = "hann",
    n_iter: Optional[int] = 32,
) -> np.ndarray:
    """Convert linear spectrogram into waveform using Griffin-Lim.

    Args:
        spc: Linear spectrogram (T, n_fft // 2 + 1).
        n_fft: The number of FFT points.
        n_shift: Shift size in points.
        win_length: Window length in points.
        window: Window function type.
        n_iter: The number of iterations.

    Returns:
        Reconstructed waveform (N,).

    """
    # assert the size of input linear spectrogram
    assert spc.shape[1] == n_fft // 2 + 1

    if LooseVersion(librosa.__version__) >= LooseVersion("0.7.0"):
        # use librosa's fast Griffin-Lim algorithm
        spc = np.abs(spc.T)
        y = librosa.griffinlim(
            S=spc,
            n_iter=n_iter,
            hop_length=n_shift,
            win_length=win_length,
            window=window,
            center=True if spc.shape[1] > 1 else False,
        )
    else:
        # use slower version of the Griffin-Lim algorithm
        logging.warning(
            "librosa version is old. use slow version of Griffin-Lim algorithm."
            "if you want to use fast Griffin-Lim, please update librosa via "
            "`source ./path.sh && pip install librosa==0.7.0`."
        )
        cspc = np.abs(spc).astype(complex).T
        angles = np.exp(2j * np.pi * np.random.rand(*cspc.shape))
        y = librosa.istft(cspc * angles, n_shift, win_length, window=window)
        for i in range(n_iter):
            angles = np.exp(
                1j
                * np.angle(librosa.stft(y, n_fft, n_shift, win_length, window=window))
            )
            y = librosa.istft(cspc * angles, n_shift, win_length, window=window)

    return y


# TODO(kan-bayashi): write as torch.nn.Module
class Spectrogram2Waveform(object):
    """Spectrogram to waveform conversion module."""

    def __init__(
        self,
        n_fft: int,
        n_shift: int,
        fs: int = None,
        n_mels: int = None,
        win_length: int = None,
        window: Optional[str] = "hann",
        fmin: int = None,
        fmax: int = None,
        griffin_lim_iters: Optional[int] = 8,
    ):
        """Initialize module.

        Args:
            fs: Sampling frequency.
            n_fft: The number of FFT points.
            n_shift: Shift size in points.
            n_mels: The number of mel basis.
            win_length: Window length in points.
            window: Window function type.
            fmin: Minimum frequency to analyze.
            fmax: Maximum frequency to analyze.
            griffin_lim_iters: The number of iterations.

        """
        assert check_argument_types()
        self.fs = fs
        self.logmel2linear = (
            partial(
                logmel2linear, fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
            )
            if n_mels is not None
            else None
        )
        self.griffin_lim = partial(
            griffin_lim,
            n_fft=n_fft,
            n_shift=n_shift,
            win_length=win_length,
            window=window,
            n_iter=griffin_lim_iters,
        )
        self.params = dict(
            n_fft=n_fft,
            n_shift=n_shift,
            win_length=win_length,
            window=window,
            n_iter=griffin_lim_iters,
        )
        if n_mels is not None:
            self.params.update(fs=fs, n_mels=n_mels, fmin=fmin, fmax=fmax)

    def __repr__(self):
        retval = f"{self.__class__.__name__}("
        for k, v in self.params.items():
            retval += f"{k}={v}, "
        retval += ")"
        return retval

    def __call__(self, spc: torch.Tensor) -> torch.Tensor:
        """Convert spectrogram to waveform.

        Args:
            spc: Log Mel filterbank (T_feats, n_mels)
                or linear spectrogram (T_feats, n_fft // 2 + 1).

        Returns:
            Tensor: Reconstructed waveform (T_wav,).

        """
        device = spc.device
        dtype = spc.dtype
        spc = spc.cpu().numpy()
        if self.logmel2linear is not None:
            spc = self.logmel2linear(spc)
        wav = self.griffin_lim(spc)
        return torch.tensor(wav).to(device=device, dtype=dtype)
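A minimal usage sketch (illustrative, not part of the commit); the 16 kHz STFT settings and the random log-mel input are assumptions chosen only to demonstrate the shapes.

import numpy as np
import torch

from funasr_local.utils.griffin_lim import Spectrogram2Waveform

spc2wav = Spectrogram2Waveform(n_fft=512, n_shift=128, fs=16000, n_mels=80,
                               fmin=0, fmax=8000, griffin_lim_iters=8)

# A fake (100, 80) log-mel spectrogram; a real one would come from a TTS model.
lmspc = torch.from_numpy(np.random.randn(100, 80).astype(np.float32))
wav = spc2wav(lmspc)   # 1-D waveform tensor
print(wav.shape)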
funasr_local/utils/job_runner.py  (new file)
@@ -0,0 +1,103 @@
from __future__ import print_function
from multiprocessing import Pool
import argparse
from tqdm import tqdm
import math


class MultiProcessRunner:
    def __init__(self, fn):
        self.args = None
        self.process = fn

    def run(self):
        parser = argparse.ArgumentParser("")
        # Task-independent options
        parser.add_argument("--nj", type=int, default=16)
        parser.add_argument("--debug", action="store_true", default=False)
        parser.add_argument("--no_pbar", action="store_true", default=False)
        parser.add_argument("--verbose", action="store_true", default=False)

        task_list, args = self.prepare(parser)
        result_list = self.pool_run(task_list, args)
        self.post(result_list, args)

    def prepare(self, parser):
        raise NotImplementedError("Please implement the prepare function.")

    def post(self, result_list, args):
        raise NotImplementedError("Please implement the post function.")

    def pool_run(self, tasks, args):
        results = []
        if args.debug:
            one_result = self.process(tasks[0])
            results.append(one_result)
        else:
            pool = Pool(args.nj)
            for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar):
                results.append(one_result)
            pool.close()

        return results


class MultiProcessRunnerV2:
    def __init__(self, fn):
        self.args = None
        self.process = fn

    def run(self):
        parser = argparse.ArgumentParser("")
        # Task-independent options
        parser.add_argument("--nj", type=int, default=16)
        parser.add_argument("--debug", action="store_true", default=False)
        parser.add_argument("--no_pbar", action="store_true", default=False)
        parser.add_argument("--verbose", action="store_true", default=False)

        task_list, args = self.prepare(parser)
        chunk_size = int(math.ceil(float(len(task_list)) / args.nj))
        if args.verbose:
            print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size))
        subtask_list = [task_list[i*chunk_size: (i+1)*chunk_size] for i in range(args.nj)]
        result_list = self.pool_run(subtask_list, args)
        self.post(result_list, args)

    def prepare(self, parser):
        raise NotImplementedError("Please implement the prepare function.")

    def post(self, result_list, args):
        raise NotImplementedError("Please implement the post function.")

    def pool_run(self, tasks, args):
        results = []
        if args.debug:
            one_result = self.process(tasks[0])
            results.append(one_result)
        else:
            pool = Pool(args.nj)
            for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar):
                results.append(one_result)
            pool.close()

        return results


class MultiProcessRunnerV3(MultiProcessRunnerV2):
    def run(self):
        parser = argparse.ArgumentParser("")
        # Task-independent options
        parser.add_argument("--nj", type=int, default=16)
        parser.add_argument("--debug", action="store_true", default=False)
        parser.add_argument("--no_pbar", action="store_true", default=False)
        parser.add_argument("--verbose", action="store_true", default=False)
        parser.add_argument("--sr", type=int, default=16000)

        task_list, shared_param, args = self.prepare(parser)
        chunk_size = int(math.ceil(float(len(task_list)) / args.nj))
        if args.verbose:
            print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size))
        subtask_list = [(i, task_list[i * chunk_size: (i + 1) * chunk_size], shared_param, args)
                        for i in range(args.nj)]
        result_list = self.pool_run(subtask_list, args)
        self.post(result_list, args)
funasr_local/utils/misc.py
Normal file
48
funasr_local/utils/misc.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import io
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
|
||||
|
||||
def statistic_model_parameters(model, prefix=None):
|
||||
var_dict = model.state_dict()
|
||||
numel = 0
|
||||
for i, key in enumerate(sorted(list([x for x in var_dict.keys() if "num_batches_tracked" not in x]))):
|
||||
if prefix is None or key.startswith(prefix):
|
||||
numel += var_dict[key].numel()
|
||||
return numel
|
||||
|
||||
|
||||
def int2vec(x, vec_dim=8, dtype=np.int):
|
||||
b = ('{:0' + str(vec_dim) + 'b}').format(x)
|
||||
# little-endian order: lower bit first
|
||||
return (np.array(list(b)[::-1]) == '1').astype(dtype)
|
||||
|
||||
|
||||
def seq2arr(seq, vec_dim=8):
|
||||
return np.row_stack([int2vec(int(x), vec_dim) for x in seq])
|
||||
|
||||
|
||||
def load_scp_as_dict(scp_path, value_type='str', kv_sep=" "):
|
||||
with io.open(scp_path, 'r', encoding='utf-8') as f:
|
||||
ret_dict = OrderedDict()
|
||||
for one_line in f.readlines():
|
||||
one_line = one_line.strip()
|
||||
pos = one_line.find(kv_sep)
|
||||
key, value = one_line[:pos], one_line[pos + 1:]
|
||||
if value_type == 'list':
|
||||
value = value.split(' ')
|
||||
ret_dict[key] = value
|
||||
return ret_dict
|
||||
|
||||
|
||||
def load_scp_as_list(scp_path, value_type='str', kv_sep=" "):
|
||||
with io.open(scp_path, 'r', encoding='utf8') as f:
|
||||
ret_dict = []
|
||||
for one_line in f.readlines():
|
||||
one_line = one_line.strip()
|
||||
pos = one_line.find(kv_sep)
|
||||
key, value = one_line[:pos], one_line[pos + 1:]
|
||||
if value_type == 'list':
|
||||
value = value.split(' ')
|
||||
ret_dict.append((key, value))
|
||||
return ret_dict
|
||||
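A short example (illustrative, not part of the commit); wav.scp is a hypothetical Kaldi-style file.

from funasr_local.utils.misc import int2vec, load_scp_as_dict

# 6 is binary 00000110; bits are emitted lowest bit first.
print(int2vec(6))  # [0 1 1 0 0 0 0 0]

# wav.scp with lines such as "utt1 /data/utt1.wav" comes back as an
# OrderedDict mapping utterance ids to paths.
scp = load_scp_as_dict("wav.scp")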
funasr_local/utils/modelscope_param.py  (new file)
@@ -0,0 +1,35 @@
class modelscope_args():
    def __init__(self,
                 task: str = "",
                 model: str = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                 data_path: str = None,
                 output_dir: str = None,
                 model_revision: str = None,
                 dataset_type: str = "small",
                 batch_bins: int = 2000,
                 max_epoch: int = None,
                 accum_grad: int = None,
                 keep_nbest_models: int = None,
                 optim: str = None,
                 lr: float = None,
                 scheduler: str = None,
                 scheduler_conf: dict = None,
                 specaug: str = None,
                 specaug_conf: dict = None,
                 ):
        self.task = task
        self.model = model
        self.data_path = data_path
        self.output_dir = output_dir
        self.model_revision = model_revision
        self.dataset_type = dataset_type
        self.batch_bins = batch_bins
        self.max_epoch = max_epoch
        self.accum_grad = accum_grad
        self.keep_nbest_models = keep_nbest_models
        self.optim = optim
        self.lr = lr
        self.scheduler = scheduler
        self.scheduler_conf = scheduler_conf
        self.specaug = specaug
        self.specaug_conf = specaug_conf
funasr_local/utils/nested_dict_action.py  (new file)
@@ -0,0 +1,106 @@
import argparse
import copy

import yaml


class NestedDictAction(argparse.Action):
    """Action class to append items to dict object.

    Examples:
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--conf', action=NestedDictAction,
        ...                         default={'a': 4})
        >>> parser.parse_args(['--conf', 'a=3', '--conf', 'c=4'])
        Namespace(conf={'a': 3, 'c': 4})
        >>> parser.parse_args(['--conf', 'c.d=4'])
        Namespace(conf={'a': 4, 'c': {'d': 4}})
        >>> parser.parse_args(['--conf', 'c.d=4', '--conf', 'c=2'])
        Namespace(conf={'a': 4, 'c': 2})
        >>> parser.parse_args(['--conf', '{d: 5, e: 9}'])
        Namespace(conf={'d': 5, 'e': 9})

    """

    _syntax = """Syntax:
  {op} <key>=<yaml-string>
  {op} <key>.<key2>=<yaml-string>
  {op} <python-dict>
  {op} <yaml-string>

e.g.
  {op} a=4
  {op} a.b={{c: true}}
  {op} {{"c": True}}
  {op} {{a: 34.5}}
"""

    def __init__(
        self,
        option_strings,
        dest,
        nargs=None,
        default=None,
        choices=None,
        required=False,
        help=None,
        metavar=None,
    ):
        super().__init__(
            option_strings=option_strings,
            dest=dest,
            nargs=nargs,
            default=copy.deepcopy(default),
            type=None,
            choices=choices,
            required=required,
            help=help,
            metavar=metavar,
        )

    def __call__(self, parser, namespace, values, option_strings=None):
        # --{option} a.b=3 -> {'a': {'b': 3}}
        if "=" in values:
            indict = copy.deepcopy(getattr(namespace, self.dest, {}))
            key, value = values.split("=", maxsplit=1)
            if not value.strip() == "":
                value = yaml.load(value, Loader=yaml.Loader)
            if not isinstance(indict, dict):
                indict = {}

            keys = key.split(".")
            d = indict
            for idx, k in enumerate(keys):
                if idx == len(keys) - 1:
                    d[k] = value
                else:
                    if not isinstance(d.setdefault(k, {}), dict):
                        # Remove the existing value and recreates as empty dict
                        d[k] = {}
                    d = d[k]

            # Update the value
            setattr(namespace, self.dest, indict)
        else:
            try:
                # At the first, try eval(), i.e. Python syntax dict.
                # e.g. --{option} "{'a': 3}" -> {'a': 3}
                # This is workaround for internal behaviour of configargparse.
                value = eval(values, {}, {})
                if not isinstance(value, dict):
                    syntax = self._syntax.format(op=option_strings)
                    mes = f"must be interpreted as dict: but got {values}\n{syntax}"
                    raise argparse.ArgumentTypeError(self, mes)
            except Exception:
                # and the second, try yaml.load
                value = yaml.load(values, Loader=yaml.Loader)
                if not isinstance(value, dict):
                    syntax = self._syntax.format(op=option_strings)
                    mes = f"must be interpreted as dict: but got {values}\n{syntax}"
                    raise argparse.ArgumentError(self, mes)

            d = getattr(namespace, self.dest, None)
            if isinstance(d, dict):
                d.update(value)
            else:
                # Remove existing params, and overwrite
                setattr(namespace, self.dest, value)
245
funasr_local/utils/postprocess_utils.py
Normal file
245
funasr_local/utils/postprocess_utils.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import string
|
||||
import logging
|
||||
from typing import Any, List, Union
|
||||
|
||||
|
||||
def isChinese(ch: str):
|
||||
if '\u4e00' <= ch <= '\u9fff' or '\u0030' <= ch <= '\u0039' or ch == '@':
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def isAllChinese(word: Union[List[Any], str]):
|
||||
word_lists = []
|
||||
for i in word:
|
||||
cur = i.replace(' ', '')
|
||||
cur = cur.replace('</s>', '')
|
||||
cur = cur.replace('<s>', '')
|
||||
cur = cur.replace('<unk>', '')
|
||||
cur = cur.replace('<OOV>', '')
|
||||
word_lists.append(cur)
|
||||
|
||||
if len(word_lists) == 0:
|
||||
return False
|
||||
|
||||
for ch in word_lists:
|
||||
if isChinese(ch) is False:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def isAllAlpha(word: Union[List[Any], str]):
|
||||
word_lists = []
|
||||
for i in word:
|
||||
cur = i.replace(' ', '')
|
||||
cur = cur.replace('</s>', '')
|
||||
cur = cur.replace('<s>', '')
|
||||
cur = cur.replace('<unk>', '')
|
||||
cur = cur.replace('<OOV>', '')
|
||||
word_lists.append(cur)
|
||||
|
||||
if len(word_lists) == 0:
|
||||
return False
|
||||
|
||||
for ch in word_lists:
|
||||
if ch.isalpha() is False and ch != "'":
|
||||
return False
|
||||
elif ch.isalpha() is True and isChinese(ch) is True:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# def abbr_dispose(words: List[Any]) -> List[Any]:
|
||||
def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
    words_size = len(words)
    word_lists = []
    abbr_begin = []
    abbr_end = []
    last_num = -1
    ts_lists = []
    ts_nums = []
    ts_index = 0
    for num in range(words_size):
        if num <= last_num:
            continue

        if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
            if num + 1 < words_size and words[num + 1] == ' ' \
                    and num + 2 < words_size and len(words[num + 2]) == 1 \
                    and words[num + 2].encode('utf-8').isalpha():
                # found the begin of abbr
                abbr_begin.append(num)
                num += 2
                abbr_end.append(num)
                # to find the end of abbr
                while True:
                    num += 1
                    if num < words_size and words[num] == ' ':
                        num += 1
                        if num < words_size and len(words[num]) == 1 \
                                and words[num].encode('utf-8').isalpha():
                            abbr_end.pop()
                            abbr_end.append(num)
                            last_num = num
                        else:
                            break
                    else:
                        break

    for num in range(words_size):
        if words[num] == ' ':
            ts_nums.append(ts_index)
        else:
            ts_nums.append(ts_index)
            ts_index += 1
    last_num = -1
    for num in range(words_size):
        if num <= last_num:
            continue

        if num in abbr_begin:
            if time_stamp is not None:
                begin = time_stamp[ts_nums[num]][0]
            abbr_word = words[num].upper()
            num += 1
            while num < words_size:
                if num in abbr_end:
                    abbr_word += words[num].upper()
                    last_num = num
                    break
                else:
                    if words[num].encode('utf-8').isalpha():
                        abbr_word += words[num].upper()
                num += 1
            word_lists.append(abbr_word)
            if time_stamp is not None:
                end = time_stamp[ts_nums[num]][1]
                ts_lists.append([begin, end])
        else:
            word_lists.append(words[num])
            if time_stamp is not None and words[num] != ' ':
                begin = time_stamp[ts_nums[num]][0]
                end = time_stamp[ts_nums[num]][1]
                ts_lists.append([begin, end])
                begin = end

    if time_stamp is not None:
        return word_lists, ts_lists
    else:
        return word_lists


def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
    middle_lists = []
    word_lists = []
    word_item = ''
    ts_lists = []

    # wash words lists
    for i in words:
        word = ''
        if isinstance(i, str):
            word = i
        else:
            word = i.decode('utf-8')

        if word in ['<s>', '</s>', '<unk>', '<OOV>']:
            continue
        else:
            middle_lists.append(word)

    # all chinese characters
    if isAllChinese(middle_lists):
        for i, ch in enumerate(middle_lists):
            word_lists.append(ch.replace(' ', ''))
        if time_stamp is not None:
            ts_lists = time_stamp

    # all alpha characters
    elif isAllAlpha(middle_lists):
        ts_flag = True
        for i, ch in enumerate(middle_lists):
            if ts_flag and time_stamp is not None:
                begin = time_stamp[i][0]
                end = time_stamp[i][1]
            word = ''
            if '@@' in ch:
                word = ch.replace('@@', '')
                word_item += word
                if time_stamp is not None:
                    ts_flag = False
                    end = time_stamp[i][1]
            else:
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''
                if time_stamp is not None:
                    ts_flag = True
                    end = time_stamp[i][1]
                    ts_lists.append([begin, end])
                    begin = end

    # mix characters
    else:
        alpha_blank = False
        ts_flag = True
        begin = -1
        end = -1
        for i, ch in enumerate(middle_lists):
            if ts_flag and time_stamp is not None:
                begin = time_stamp[i][0]
                end = time_stamp[i][1]
            word = ''
            if isAllChinese(ch):
                if alpha_blank is True:
                    word_lists.pop()
                word_lists.append(ch)
                alpha_blank = False
                if time_stamp is not None:
                    ts_flag = True
                    ts_lists.append([begin, end])
                    begin = end
            elif '@@' in ch:
                word = ch.replace('@@', '')
                word_item += word
                alpha_blank = False
                if time_stamp is not None:
                    ts_flag = False
                    end = time_stamp[i][1]
            elif isAllAlpha(ch):
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''
                alpha_blank = True
                if time_stamp is not None:
                    ts_flag = True
                    end = time_stamp[i][1]
                    ts_lists.append([begin, end])
                    begin = end
            else:
                word_lists.append(ch)

    if time_stamp is not None:
        word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
        real_word_lists = []
        for ch in word_lists:
            if ch != ' ':
                real_word_lists.append(ch)
        sentence = ' '.join(real_word_lists).strip()
        return sentence, ts_lists, real_word_lists
    else:
        word_lists = abbr_dispose(word_lists)
        real_word_lists = []
        for ch in word_lists:
            if ch != ' ':
                real_word_lists.append(ch)
        sentence = ''.join(word_lists).strip()
        return sentence, real_word_lists
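A hedged illustration of abbr_dispose (tokens invented, no timestamps): single letters separated by blanks are collapsed into one upper-case abbreviation, while other words pass through unchanged.

# Illustrative only; input list is invented.
print(abbr_dispose(['a', ' ', 'b', ' ', 'c', ' ', 'model', ' ']))
# ['ABC', ' ', 'model', ' ']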
75
funasr_local/utils/sized_dict.py
Normal file
@@ -0,0 +1,75 @@
import collections
import sys

from torch import multiprocessing


def get_size(obj, seen=None):
    """Recursively finds size of objects

    Taken from https://github.com/bosswissam/pysize

    """

    size = sys.getsizeof(obj)
    if seen is None:
        seen = set()

    obj_id = id(obj)
    if obj_id in seen:
        return 0

    # Important: mark as seen *before* entering recursion to gracefully handle
    # self-referential objects
    seen.add(obj_id)

    if isinstance(obj, dict):
        size += sum([get_size(v, seen) for v in obj.values()])
        size += sum([get_size(k, seen) for k in obj.keys()])
    elif hasattr(obj, "__dict__"):
        size += get_size(obj.__dict__, seen)
    elif isinstance(obj, (list, set, tuple)):
        size += sum([get_size(i, seen) for i in obj])

    return size


class SizedDict(collections.abc.MutableMapping):
    def __init__(self, shared: bool = False, data: dict = None):
        if data is None:
            data = {}

        if shared:
            # NOTE(kamo): Don't set manager as a field because Manager, which includes
            # weakref object, causes following error with method="spawn",
            # "TypeError: can't pickle weakref objects"
            self.cache = multiprocessing.Manager().dict(**data)
        else:
            self.manager = None
            self.cache = dict(**data)
        self.size = 0

    def __setitem__(self, key, value):
        if key in self.cache:
            self.size -= get_size(self.cache[key])
        else:
            self.size += sys.getsizeof(key)
        self.size += get_size(value)
        self.cache[key] = value

    def __getitem__(self, key):
        return self.cache[key]

    def __delitem__(self, key):
        self.size -= get_size(self.cache[key])
        self.size -= sys.getsizeof(key)
        del self.cache[key]

    def __iter__(self):
        return iter(self.cache)

    def __contains__(self, key):
        return key in self.cache

    def __len__(self):
        return len(self.cache)
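A hedged usage sketch of SizedDict (values invented): the mapping keeps a running estimate of the bytes held by its keys and values, which callers can use to cap cache growth.

cache = SizedDict(shared=False)
cache["utt1"] = [0.0] * 1000
print(cache.size)      # approximate bytes for the stored key and value
del cache["utt1"]
print(len(cache))      # 0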
320
funasr_local/utils/timestamp_tools.py
Normal file
@@ -0,0 +1,320 @@
from itertools import zip_longest

import torch
import copy
import codecs
import logging
import edit_distance
import argparse
import pdb
import numpy as np
from typing import Any, List, Tuple, Union


def ts_prediction_lfr6_standard(us_alphas,
                                us_peaks,
                                char_list,
                                vad_offset=0.0,
                                force_time_shift=-1.5,
                                sil_in_str=True):
    if not len(char_list):
        return "", []
    START_END_THRESHOLD = 5
    MAX_TOKEN_DURATION = 12
    TIME_RATE = 10.0 * 6 / 1000 / 3  # 3 times upsampled
    if len(us_alphas.shape) == 2:
        _, peaks = us_alphas[0], us_peaks[0]  # support inference batch_size=1 only
    else:
        _, peaks = us_alphas, us_peaks
    num_frames = peaks.shape[0]
    if char_list[-1] == '</s>':
        char_list = char_list[:-1]
    timestamp_list = []
    new_char_list = []
    # for bicif model trained with large data, cif2 actually fires when a character starts
    # so treat the frames between two peaks as the duration of the former token
    fire_place = torch.where(peaks > 1.0 - 1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    num_peak = len(fire_place)
    assert num_peak == len(char_list) + 1  # number of peaks is supposed to be number of tokens + 1
    # begin silence
    if fire_place[0] > START_END_THRESHOLD:
        # char_list.insert(0, '<sil>')
        timestamp_list.append([0.0, fire_place[0] * TIME_RATE])
        new_char_list.append('<sil>')
    # tokens timestamp
    for i in range(len(fire_place) - 1):
        new_char_list.append(char_list[i])
        if MAX_TOKEN_DURATION < 0 or fire_place[i + 1] - fire_place[i] <= MAX_TOKEN_DURATION:
            timestamp_list.append([fire_place[i] * TIME_RATE, fire_place[i + 1] * TIME_RATE])
        else:
            # if the zero-weight frames last too long, cut the span into a token part and a silence part
            _split = fire_place[i] + MAX_TOKEN_DURATION
            timestamp_list.append([fire_place[i] * TIME_RATE, _split * TIME_RATE])
            timestamp_list.append([_split * TIME_RATE, fire_place[i + 1] * TIME_RATE])
            new_char_list.append('<sil>')
    # tail token and end silence
    # new_char_list.append(char_list[-1])
    if num_frames - fire_place[-1] > START_END_THRESHOLD:
        _end = (num_frames + fire_place[-1]) * 0.5
        # _end = fire_place[-1]
        timestamp_list[-1][1] = _end * TIME_RATE
        timestamp_list.append([_end * TIME_RATE, num_frames * TIME_RATE])
        new_char_list.append("<sil>")
    else:
        timestamp_list[-1][1] = num_frames * TIME_RATE
    if vad_offset:  # add offset time in model with vad
        for i in range(len(timestamp_list)):
            timestamp_list[i][0] = timestamp_list[i][0] + vad_offset / 1000.0
            timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
    res_txt = ""
    for char, timestamp in zip(new_char_list, timestamp_list):
        # if char != '<sil>':
        if not sil_in_str and char == '<sil>':
            continue
        res_txt += "{} {} {};".format(char, str(timestamp[0] + 0.0005)[:5], str(timestamp[1] + 0.0005)[:5])
    res = []
    for char, timestamp in zip(new_char_list, timestamp_list):
        if char != '<sil>':
            res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
    return res_txt, res
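A hedged sanity check of the peak-to-time arithmetic above (frame indices invented): each upsampled frame covers TIME_RATE seconds, so two consecutive CIF peaks give the token's start and end times.

TIME_RATE = 10.0 * 6 / 1000 / 3      # 0.02 s per upsampled frame
fire_place = [50, 62]                # two hypothetical, offset-corrected peaks
start, end = fire_place[0] * TIME_RATE, fire_place[1] * TIME_RATE
print(round(start, 2), round(end, 2))    # 1.0 1.24  -> stored as [1000, 1240] ms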
def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed):
    res = []
    if text_postprocessed is None:
        return res
    if time_stamp_postprocessed is None:
        return res
    if len(time_stamp_postprocessed) == 0:
        return res
    if len(text_postprocessed) == 0:
        return res

    if punc_id_list is None or len(punc_id_list) == 0:
        res.append({
            'text': text_postprocessed.split(),
            "start": time_stamp_postprocessed[0][0],
            "end": time_stamp_postprocessed[-1][1]
        })
        return res
    if len(punc_id_list) != len(time_stamp_postprocessed):
        print("warning: length mismatch between punctuation ids and timestamps")
    sentence_text = ''
    sentence_start = time_stamp_postprocessed[0][0]
    sentence_end = time_stamp_postprocessed[0][1]
    texts = text_postprocessed.split()
    punc_stamp_text_list = list(zip_longest(punc_id_list, time_stamp_postprocessed, texts, fillvalue=None))
    for punc_stamp_text in punc_stamp_text_list:
        punc_id, time_stamp, text = punc_stamp_text
        sentence_text += text if text is not None else ''
        punc_id = int(punc_id) if punc_id is not None else 1
        sentence_end = time_stamp[1] if time_stamp is not None else sentence_end

        if punc_id == 2:
            sentence_text += ','
            res.append({
                'text': sentence_text,
                "start": sentence_start,
                "end": sentence_end
            })
            sentence_text = ''
            sentence_start = sentence_end
        elif punc_id == 3:
            sentence_text += '.'
            res.append({
                'text': sentence_text,
                "start": sentence_start,
                "end": sentence_end
            })
            sentence_text = ''
            sentence_start = sentence_end
        elif punc_id == 4:
            sentence_text += '?'
            res.append({
                'text': sentence_text,
                "start": sentence_start,
                "end": sentence_end
            })
            sentence_text = ''
            sentence_start = sentence_end
    return res
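A hedged illustration of the punctuation-id convention assumed above (2 maps to ',', 3 to '.', 4 to '?'); the inputs are invented and the tokens are concatenated without spaces, exactly as the loop does.

sentences = time_stamp_sentence(
    punc_id_list=[1, 3],                              # no punctuation, then a period
    time_stamp_postprocessed=[[0, 500], [500, 900]],  # ms per token
    text_postprocessed="hello world")
# -> [{'text': 'helloworld.', 'start': 0, 'end': 900}]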
class AverageShiftCalculator():
    def __init__(self):
        logging.warning("Calculating average shift.")

    def __call__(self, file1, file2):
        uttid_list1, ts_dict1 = self.read_timestamps(file1)
        uttid_list2, ts_dict2 = self.read_timestamps(file2)
        uttid_intersection = self._intersection(uttid_list1, uttid_list2)
        res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2)
        logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8]))
        logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))

    def _intersection(self, list1, list2):
        set1 = set(list1)
        set2 = set(list2)
        if set1 == set2:
            logging.warning("Uttid same checked.")
            return set1
        itsc = list(set1 & set2)
        logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc)))
        return itsc

    def read_timestamps(self, file):
        # read timestamps file in standard format
        uttid_list = []
        ts_dict = {}
        with codecs.open(file, 'r') as fin:
            for line in fin.readlines():
                text = ''
                ts_list = []
                line = line.rstrip()
                uttid = line.split()[0]
                uttid_list.append(uttid)
                body = " ".join(line.split()[1:])
                for pd in body.split(';'):
                    if not len(pd):
                        continue
                    # pdb.set_trace()
                    char, start, end = pd.lstrip(" ").split(' ')
                    text += char + ','
                    ts_list.append((float(start), float(end)))
                # ts_lists.append(ts_list)
                ts_dict[uttid] = (text[:-1], ts_list)
        logging.warning("File {} read done.".format(file))
        return uttid_list, ts_dict

    def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
        shift_time = 0
        for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2):
            shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1])
        num_tokens = len(filtered_timestamp_list1)
        return shift_time, num_tokens

    def as_cal(self, uttid_list, ts_dict1, ts_dict2):
        # calculate average shift between timestamp1 and timestamp2
        # when characters differ, use edit distance alignment
        # and calculate the error between the same characters
        self._accumlated_shift = 0
        self._accumlated_tokens = 0
        self.max_shift = 0
        self.max_shift_uttid = None
        for uttid in uttid_list:
            (t1, ts1) = ts_dict1[uttid]
            (t2, ts2) = ts_dict2[uttid]
            _align, _align2, _align3 = [], [], []
            fts1, fts2 = [], []
            _t1, _t2 = [], []
            sm = edit_distance.SequenceMatcher(t1.split(','), t2.split(','))
            s = sm.get_opcodes()
            for j in range(len(s)):
                if s[j][0] == "replace" or s[j][0] == "insert":
                    _align.append(0)
                if s[j][0] == "replace" or s[j][0] == "delete":
                    _align3.append(0)
                elif s[j][0] == "equal":
                    _align.append(1)
                    _align3.append(1)
                else:
                    continue
            # use s to index t2
            for a, ts, t in zip(_align, ts2, t2.split(',')):
                if a:
                    fts2.append(ts)
                    _t2.append(t)
            sm2 = edit_distance.SequenceMatcher(t2.split(','), t1.split(','))
            s = sm2.get_opcodes()
            for j in range(len(s)):
                if s[j][0] == "replace" or s[j][0] == "insert":
                    _align2.append(0)
                elif s[j][0] == "equal":
                    _align2.append(1)
                else:
                    continue
            # use s2 to index t1
            for a, ts, t in zip(_align3, ts1, t1.split(',')):
                if a:
                    fts1.append(ts)
                    _t1.append(t)
            if len(fts1) == len(fts2):
                shift_time, num_tokens = self._shift(fts1, fts2)
                self._accumlated_shift += shift_time
                self._accumlated_tokens += num_tokens
                if shift_time / num_tokens > self.max_shift:
                    self.max_shift = shift_time / num_tokens
                    self.max_shift_uttid = uttid
            else:
                logging.warning("length mismatch")
        return self._accumlated_shift / self._accumlated_tokens
def convert_external_alphas(alphas_file, text_file, output_file):
    from funasr_local.models.predictor.cif import cif_wo_hidden
    with open(alphas_file, 'r') as f1, open(text_file, 'r') as f2, open(output_file, 'w') as f3:
        for line1, line2 in zip(f1.readlines(), f2.readlines()):
            line1 = line1.rstrip()
            line2 = line2.rstrip()
            assert line1.split()[0] == line2.split()[0]
            uttid = line1.split()[0]
            alphas = [float(i) for i in line1.split()[1:]]
            new_alphas = np.array(remove_chunk_padding(alphas))
            new_alphas[-1] += 1e-4
            text = line2.split()[1:]
            if len(text) + 1 != int(new_alphas.sum()):
                # force resize
                new_alphas *= (len(text) + 1) / int(new_alphas.sum())
            peaks = cif_wo_hidden(torch.Tensor(new_alphas).unsqueeze(0), 1.0 - 1e-4)
            if " " in text:
                text = text.split()
            else:
                text = [i for i in text]
            res_str, _ = ts_prediction_lfr6_standard(new_alphas, peaks[0], text,
                                                     force_time_shift=-7.0,
                                                     sil_in_str=False)
            f3.write("{} {}\n".format(uttid, res_str))


def remove_chunk_padding(alphas):
    # remove the padding part in alphas if using chunk paraformer for GPU
    START_ZERO = 45
    MID_ZERO = 75
    REAL_FRAMES = 360  # for chunk based encoder 10-120-10 and fsmn padding 5
    alphas = alphas[START_ZERO:]  # remove the padding at beginning
    new_alphas = []
    while True:
        new_alphas = new_alphas + alphas[:REAL_FRAMES]
        alphas = alphas[REAL_FRAMES + MID_ZERO:]
        if len(alphas) < REAL_FRAMES:
            break
    return new_alphas


SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']


def main(args):
    if args.mode == 'cal_aas':
        asc = AverageShiftCalculator()
        asc(args.input, args.input2)
    elif args.mode == 'read_ext_alphas':
        convert_external_alphas(args.input, args.input2, args.output)
    else:
        logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='timestamp tools')
    parser.add_argument('--mode',
                        default=None,
                        type=str,
                        choices=SUPPORTED_MODES,
                        help='timestamp related toolbox')
    parser.add_argument('--input', default=None, type=str, help='input file path')
    parser.add_argument('--output', default=None, type=str, help='output file name')
    parser.add_argument('--input2', default=None, type=str, help='input2 file path')
    parser.add_argument('--kaldi-ts-type',
                        default='v2',
                        type=str,
                        choices=['v0', 'v1', 'v2'],
                        help='kaldi timestamp to write')
    args = parser.parse_args()
    main(args)
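A hedged usage sketch of the calculator above (file paths invented): both inputs are expected in the "uttid char start end;char start end;..." format produced by ts_prediction_lfr6_standard and parsed by read_timestamps.

# Equivalent to: python funasr_local/utils/timestamp_tools.py --mode cal_aas --input ref.ts --input2 hyp.ts
asc = AverageShiftCalculator()
asc("ref.ts", "hyp.ts")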
149
funasr_local/utils/types.py
Normal file
@@ -0,0 +1,149 @@
from distutils.util import strtobool
from typing import Optional
from typing import Tuple
from typing import Union

import humanfriendly


def str2bool(value: str) -> bool:
    return bool(strtobool(value))


def remove_parenthesis(value: str):
    value = value.strip()
    if value.startswith("(") and value.endswith(")"):
        value = value[1:-1]
    elif value.startswith("[") and value.endswith("]"):
        value = value[1:-1]
    return value


def remove_quotes(value: str):
    value = value.strip()
    if value.startswith('"') and value.endswith('"'):
        value = value[1:-1]
    elif value.startswith("'") and value.endswith("'"):
        value = value[1:-1]
    return value


def int_or_none(value: str) -> Optional[int]:
    """int_or_none.

    Examples:
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=int_or_none)
        >>> parser.parse_args(['--foo', '456'])
        Namespace(foo=456)
        >>> parser.parse_args(['--foo', 'none'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'null'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'nil'])
        Namespace(foo=None)

    """
    if value.strip().lower() in ("none", "null", "nil"):
        return None
    return int(value)


def float_or_none(value: str) -> Optional[float]:
    """float_or_none.

    Examples:
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=float_or_none)
        >>> parser.parse_args(['--foo', '4.5'])
        Namespace(foo=4.5)
        >>> parser.parse_args(['--foo', 'none'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'null'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'nil'])
        Namespace(foo=None)

    """
    if value.strip().lower() in ("none", "null", "nil"):
        return None
    return float(value)


def humanfriendly_parse_size_or_none(value) -> Optional[float]:
    if value.strip().lower() in ("none", "null", "nil"):
        return None
    return humanfriendly.parse_size(value)


def str_or_int(value: str) -> Union[str, int]:
    try:
        return int(value)
    except ValueError:
        return value


def str_or_none(value: str) -> Optional[str]:
    """str_or_none.

    Examples:
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=str_or_none)
        >>> parser.parse_args(['--foo', 'aaa'])
        Namespace(foo='aaa')
        >>> parser.parse_args(['--foo', 'none'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'null'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'nil'])
        Namespace(foo=None)

    """
    if value.strip().lower() in ("none", "null", "nil"):
        return None
    return value


def str2pair_str(value: str) -> Tuple[str, str]:
    """str2pair_str.

    Examples:
        >>> import argparse
        >>> str2pair_str('abc,def ')
        ('abc', 'def')
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=str2pair_str)
        >>> parser.parse_args(['--foo', 'abc,def'])
        Namespace(foo=('abc', 'def'))

    """
    value = remove_parenthesis(value)
    a, b = value.split(",")

    # Workaround for configargparse issues:
    # If the list values are given from yaml file,
    # the value given to type() is shaped as python-list,
    # e.g. ['a', 'b', 'c'],
    # so we need to remove double quotes from it.
    return remove_quotes(a), remove_quotes(b)


def str2triple_str(value: str) -> Tuple[str, str, str]:
    """str2triple_str.

    Examples:
        >>> str2triple_str('abc,def ,ghi')
        ('abc', 'def', 'ghi')
    """
    value = remove_parenthesis(value)
    a, b, c = value.split(",")

    # Workaround for configargparse issues:
    # If the list values are given from yaml file,
    # the value given to type() is shaped as python-list,
    # e.g. ['a', 'b', 'c'],
    # so we need to remove quotes from it.
    return remove_quotes(a), remove_quotes(b), remove_quotes(c)
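Since humanfriendly_parse_size_or_none carries no doctest above, a brief hedged example (values illustrative; exact parsing follows the installed humanfriendly package, which uses decimal units by default):

print(humanfriendly_parse_size_or_none('5GB'))    # 5000000000
print(humanfriendly_parse_size_or_none('none'))   # None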
321
funasr_local/utils/wav_utils.py
Normal file
@@ -0,0 +1,321 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import math
import os
import shutil
from multiprocessing import Pool
from typing import Any, Dict, Union

# import kaldiio
import librosa
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi


def ndarray_resample(audio_in: np.ndarray,
                     fs_in: int = 16000,
                     fs_out: int = 16000) -> np.ndarray:
    audio_out = audio_in
    if fs_in != fs_out:
        audio_out = librosa.resample(audio_in, orig_sr=fs_in, target_sr=fs_out)
    return audio_out


def torch_resample(audio_in: torch.Tensor,
                   fs_in: int = 16000,
                   fs_out: int = 16000) -> torch.Tensor:
    audio_out = audio_in
    if fs_in != fs_out:
        audio_out = torchaudio.transforms.Resample(orig_freq=fs_in,
                                                   new_freq=fs_out)(audio_in)
    return audio_out
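A hedged usage sketch of the resample helpers (the wav path is invented): load 8 kHz audio and bring it to the 16 kHz rate the models expect.

waveform, sr = torchaudio.load("example_8k.wav")
waveform_16k = torch_resample(waveform, fs_in=sr, fs_out=16000)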
def extract_CMVN_featrures(mvn_file):
    """
    extract CMVN from cmvn.ark
    """

    if not os.path.exists(mvn_file):
        return None
    try:
        cmvn = kaldiio.load_mat(mvn_file)
        means = []
        variance = []

        for i in range(cmvn.shape[1] - 1):
            means.append(float(cmvn[0][i]))

        count = float(cmvn[0][-1])

        for i in range(cmvn.shape[1] - 1):
            variance.append(float(cmvn[1][i]))

        for i in range(len(means)):
            means[i] /= count
            variance[i] = variance[i] / count - means[i] * means[i]
            if variance[i] < 1.0e-20:
                variance[i] = 1.0e-20
            variance[i] = 1.0 / math.sqrt(variance[i])

        cmvn = np.array([means, variance])
        return cmvn
    except Exception:
        # kaldiio is commented out above, so the resulting NameError (like any
        # other failure) falls through to the Kaldi-nnet text parser below
        cmvn = extract_CMVN_features_txt(mvn_file)
        return cmvn


def extract_CMVN_features_txt(mvn_file):  # noqa
    with open(mvn_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    add_shift_list = []
    rescale_list = []
    for i in range(len(lines)):
        line_item = lines[i].split()
        if line_item[0] == '<AddShift>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                add_shift_line = line_item[3:(len(line_item) - 1)]
                add_shift_list = list(add_shift_line)
                continue
        elif line_item[0] == '<Rescale>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                rescale_line = line_item[3:(len(line_item) - 1)]
                rescale_list = list(rescale_line)
                continue
    add_shift_list_f = [float(s) for s in add_shift_list]
    rescale_list_f = [float(s) for s in rescale_list]
    cmvn = np.array([add_shift_list_f, rescale_list_f])
    return cmvn
def build_LFR_features(inputs, m=7, n=6):  # noqa
    """
    Actually, this implements stacking frames and skipping frames.
    if m = 1 and n = 1, just return the origin features.
    if m = 1 and n > 1, it works like skipping.
    if m > 1 and n = 1, it works like stacking but only support right frames.
    if m > 1 and n > 1, it works like LFR.

    Args:
        inputs_batch: inputs is T x D np.ndarray
        m: number of frames to stack
        n: number of frames to skip
    """
    # LFR_inputs_batch = []
    # for inputs in inputs_batch:
    LFR_inputs = []
    T = inputs.shape[0]
    T_lfr = int(np.ceil(T / n))
    left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
    inputs = np.vstack((left_padding, inputs))
    T = T + (m - 1) // 2
    for i in range(T_lfr):
        if m <= T - i * n:
            LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
        else:  # process last LFR frame
            num_padding = m - (T - i * n)
            frame = np.hstack(inputs[i * n:])
            for _ in range(num_padding):
                frame = np.hstack((frame, inputs[-1]))
            LFR_inputs.append(frame)
    return np.vstack(LFR_inputs)
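A hedged shape check for build_LFR_features (input sizes invented): with the defaults m=7, n=6 and 80-dim fbank features, 100 input frames yield ceil(100/6) = 17 low-frame-rate frames of dimension 7 * 80 = 560.

feats = np.random.randn(100, 80).astype(np.float32)   # hypothetical fbank matrix
lfr = build_LFR_features(feats, m=7, n=6)
print(lfr.shape)                                       # (17, 560)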
def compute_fbank(wav_file,
                  num_mel_bins=80,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0,
                  is_pcm=False,
                  fs: Union[int, Dict[Any, int]] = 16000):
    audio_sr: int = 16000
    model_sr: int = 16000
    if isinstance(fs, int):
        model_sr = fs
        audio_sr = fs
    else:
        model_sr = fs['model_fs']
        audio_sr = fs['audio_fs']

    if is_pcm is True:
        # byte(PCM16) to float32, and resample
        value = wav_file
        middle_data = np.frombuffer(value, dtype=np.int16)
        middle_data = np.asarray(middle_data)
        if middle_data.dtype.kind not in 'iu':
            raise TypeError("'middle_data' must be an array of integers")
        dtype = np.dtype('float32')
        if dtype.kind != 'f':
            raise TypeError("'dtype' must be a floating point type")

        i = np.iinfo(middle_data.dtype)
        abs_max = 2 ** (i.bits - 1)
        offset = i.min + abs_max
        waveform = np.frombuffer(
            (middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
        waveform = ndarray_resample(waveform, audio_sr, model_sr)
        waveform = torch.from_numpy(waveform.reshape(1, -1))
    else:
        # load pcm from wav, and resample
        waveform, audio_sr = torchaudio.load(wav_file)
        waveform = waveform * (1 << 15)
        waveform = torch_resample(waveform, audio_sr, model_sr)

    mat = kaldi.fbank(waveform,
                      num_mel_bins=num_mel_bins,
                      frame_length=frame_length,
                      frame_shift=frame_shift,
                      dither=dither,
                      energy_floor=0.0,
                      window_type='hamming',
                      sample_frequency=model_sr)

    input_feats = mat

    return input_feats
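A hedged usage sketch (wav path invented): extract 80-dim fbank features from a 16 kHz wav and optionally apply the LFR stacking defined above.

fbank = compute_fbank("example_16k.wav", num_mel_bins=80, fs=16000)  # torch.Tensor of shape (num_frames, 80)
lfr_feats = build_LFR_features(fbank.numpy(), m=7, n=6)              # (ceil(num_frames / 6), 560)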
def wav2num_frame(wav_path, frontend_conf):
    waveform, sampling_rate = torchaudio.load(wav_path)
    speech_length = (waveform.shape[1] / sampling_rate) * 1000.
    n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
    feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
    return n_frames, feature_dim, speech_length


def calc_shape_core(root_path, frontend_conf, speech_length_min, speech_length_max, idx):
    wav_scp_file = os.path.join(root_path, "wav.scp.{}".format(idx))
    shape_file = os.path.join(root_path, "speech_shape.{}".format(idx))
    with open(wav_scp_file) as f:
        lines = f.readlines()
    with open(shape_file, "w") as f:
        for line in lines:
            sample_name, wav_path = line.strip().split()
            n_frames, feature_dim, speech_length = wav2num_frame(wav_path, frontend_conf)
            write_flag = True
            if speech_length_min > 0 and speech_length < speech_length_min:
                write_flag = False
            if speech_length_max > 0 and speech_length > speech_length_max:
                write_flag = False
            if write_flag:
                f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim))))
                f.flush()


def calc_shape(data_dir, dataset, frontend_conf, speech_length_min=-1, speech_length_max=-1, nj=32):
    shape_path = os.path.join(data_dir, dataset, "shape_files")
    if os.path.exists(shape_path):
        assert os.path.exists(os.path.join(data_dir, dataset, "speech_shape"))
        print('Shape file for small dataset already exists.')
        return
    os.makedirs(shape_path, exist_ok=True)

    # split
    wav_scp_file = os.path.join(data_dir, dataset, "wav.scp")
    with open(wav_scp_file) as f:
        lines = f.readlines()
        num_lines = len(lines)
        num_job_lines = num_lines // nj
    start = 0
    for i in range(nj):
        end = start + num_job_lines
        file = os.path.join(shape_path, "wav.scp.{}".format(str(i + 1)))
        with open(file, "w") as f:
            if i == nj - 1:
                f.writelines(lines[start:])
            else:
                f.writelines(lines[start:end])
        start = end

    p = Pool(nj)
    for i in range(nj):
        p.apply_async(calc_shape_core,
                      args=(shape_path, frontend_conf, speech_length_min, speech_length_max, str(i + 1)))
    print('Generating shape files, please wait a few minutes...')
    p.close()
    p.join()

    # combine
    file = os.path.join(data_dir, dataset, "speech_shape")
    with open(file, "w") as f:
        for i in range(nj):
            job_file = os.path.join(shape_path, "speech_shape.{}".format(str(i + 1)))
            with open(job_file) as job_f:
                lines = job_f.readlines()
                f.writelines(lines)
    print('Generating shape files done.')


def generate_data_list(data_dir, dataset, nj=100):
    split_dir = os.path.join(data_dir, dataset, "split")
    if os.path.exists(split_dir):
        assert os.path.exists(os.path.join(data_dir, dataset, "data.list"))
        print('Data list for large dataset already exists.')
        return
    os.makedirs(split_dir, exist_ok=True)

    with open(os.path.join(data_dir, dataset, "wav.scp")) as f_wav:
        wav_lines = f_wav.readlines()
    with open(os.path.join(data_dir, dataset, "text")) as f_text:
        text_lines = f_text.readlines()
    total_num_lines = len(wav_lines)
    num_lines = total_num_lines // nj
    start_num = 0
    for i in range(nj):
        end_num = start_num + num_lines
        split_dir_nj = os.path.join(split_dir, str(i + 1))
        os.mkdir(split_dir_nj)
        wav_file = os.path.join(split_dir_nj, 'wav.scp')
        text_file = os.path.join(split_dir_nj, "text")
        with open(wav_file, "w") as fw, open(text_file, "w") as ft:
            if i == nj - 1:
                fw.writelines(wav_lines[start_num:])
                ft.writelines(text_lines[start_num:])
            else:
                fw.writelines(wav_lines[start_num:end_num])
                ft.writelines(text_lines[start_num:end_num])
            start_num = end_num

    data_list_file = os.path.join(data_dir, dataset, "data.list")
    with open(data_list_file, "w") as f_data:
        for i in range(nj):
            wav_path = os.path.join(split_dir, str(i + 1), "wav.scp")
            text_path = os.path.join(split_dir, str(i + 1), "text")
            f_data.write(wav_path + " " + text_path + "\n")


def filter_wav_text(data_dir, dataset):
    wav_file = os.path.join(data_dir, dataset, "wav.scp")
    text_file = os.path.join(data_dir, dataset, "text")
    with open(wav_file) as f_wav, open(text_file) as f_text:
        wav_lines = f_wav.readlines()
        text_lines = f_text.readlines()
    os.rename(wav_file, "{}.bak".format(wav_file))
    os.rename(text_file, "{}.bak".format(text_file))
    wav_dict = {}
    for line in wav_lines:
        parts = line.strip().split()
        if len(parts) != 2:
            continue
        sample_name, wav_path = parts
        wav_dict[sample_name] = wav_path
    text_dict = {}
    for line in text_lines:
        parts = line.strip().split()
        if len(parts) < 2:
            continue
        sample_name = parts[0]
        text_dict[sample_name] = " ".join(parts[1:]).lower()
    filter_count = 0
    with open(wav_file, "w") as f_wav, open(text_file, "w") as f_text:
        for sample_name, wav_path in wav_dict.items():
            if sample_name in text_dict.keys():
                f_wav.write(sample_name + " " + wav_path + "\n")
                f_text.write(sample_name + " " + text_dict[sample_name] + "\n")
            else:
                filter_count += 1
    print("{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".format(filter_count, len(wav_lines), dataset))
14
funasr_local/utils/yaml_no_alias_safe_dump.py
Normal file
@@ -0,0 +1,14 @@
import yaml


class NoAliasSafeDumper(yaml.SafeDumper):
    # Disable anchor/alias in yaml because it looks ugly
    def ignore_aliases(self, data):
        return True


def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
    """Safe-dump in yaml with no anchor/alias"""
    return yaml.dump(
        data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
    )
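A brief hedged example of why this dumper exists (data invented): plain yaml.safe_dump would emit &id001/*id001 anchors for the shared list, while yaml_no_alias_safe_dump repeats the values inline, roughly as shown in the comment.

shared = [1, 2, 3]
print(yaml_no_alias_safe_dump({"a": shared, "b": shared}))
# a:
# - 1
# - 2
# - 3
# b:
# - 1
# - 2
# - 3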