add files

This commit is contained in:
烨玮
2025-02-20 12:17:03 +08:00
parent a21dd4555c
commit edd008441b
667 changed files with 473123 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

File diff suppressed because it is too large Load Diff

1723
funasr_local/tasks/asr.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,376 @@
import argparse
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.datasets.collate_fn import CommonCollateFn
from funasr_local.datasets.preprocessor import CommonPreprocessor
from funasr_local.layers.abs_normalize import AbsNormalize
from funasr_local.layers.global_mvn import GlobalMVN
from funasr_local.layers.utterance_mvn import UtteranceMVN
from funasr_local.models.data2vec import Data2VecPretrainModel
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.frontend.default import DefaultFrontend
from funasr_local.models.frontend.windowing import SlidingWindow
from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr_local.models.preencoder.sinc import LightweightSincConvs
from funasr_local.models.specaug.abs_specaug import AbsSpecAug
from funasr_local.models.specaug.specaug import SpecAug
from funasr_local.tasks.abs_task import AbsTask
from funasr_local.text.phoneme_tokenizer import g2p_choices
from funasr_local.torch_utils.initialize import initialize
from funasr_local.train.class_choices import ClassChoices
from funasr_local.train.trainer import Trainer
from funasr_local.utils.types import float_or_none
from funasr_local.utils.types import int_or_none
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str_or_none
frontend_choices = ClassChoices(
name="frontend",
classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow),
type_check=AbsFrontend,
default="default",
)
specaug_choices = ClassChoices(
name="specaug",
classes=dict(specaug=SpecAug),
type_check=AbsSpecAug,
default=None,
optional=True,
)
normalize_choices = ClassChoices(
"normalize",
classes=dict(
global_mvn=GlobalMVN,
utterance_mvn=UtteranceMVN,
),
type_check=AbsNormalize,
default=None,
optional=True,
)
preencoder_choices = ClassChoices(
name="preencoder",
classes=dict(
sinc=LightweightSincConvs,
),
type_check=AbsPreEncoder,
default=None,
optional=True,
)
encoder_choices = ClassChoices(
"encoder",
classes=dict(
data2vec_encoder=Data2VecEncoder,
),
type_check=AbsEncoder,
default="data2vec_encoder",
)
model_choices = ClassChoices(
"model",
classes=dict(
data2vec=Data2VecPretrainModel,
),
default="data2vec",
)
class Data2VecTask(AbsTask):
# If you need more than one optimizers, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --preencoder and --preencoder_conf
preencoder_choices,
# --encoder and --encoder_conf
encoder_choices,
# --model and --model_conf
model_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(description="Task related")
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
group.add_argument(
"--token_list",
type=str_or_none,
default=None,
help="A text mapping int-id to token",
)
group.add_argument(
"--init",
type=lambda x: str_or_none(x.lower()),
default=None,
help="The initialization method",
choices=[
"chainer",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
None,
],
)
group.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of input dimension of the feature",
)
group = parser.add_argument_group(description="Preprocess related")
group.add_argument(
"--use_preprocessor",
type=str2bool,
default=True,
help="Apply preprocessing to data or not",
)
group.add_argument(
"--token_type",
type=str,
default=None,
choices=["bpe", "char", "word", "phn"],
help="The text will be tokenized " "in the specified level token",
)
group.add_argument(
"--feats_type",
type=str,
default='fbank',
help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model file of sentencepiece",
)
parser.add_argument(
"--non_linguistic_symbols",
type=str_or_none,
help="non_linguistic_symbols file path",
)
parser.add_argument(
"--cleaner",
type=str_or_none,
choices=[None, "tacotron", "jaconv", "vietnamese"],
default=None,
help="Apply text cleaning",
)
parser.add_argument(
"--g2p",
type=str_or_none,
choices=g2p_choices,
default=None,
help="Specify g2p method if --token_type=phn",
)
parser.add_argument(
"--speech_volume_normalize",
type=float_or_none,
default=None,
help="Scale the maximum amplitude to the given value.",
)
parser.add_argument(
"--rir_scp",
type=str_or_none,
default=None,
help="The file path of rir scp file.",
)
parser.add_argument(
"--rir_apply_prob",
type=float,
default=1.0,
help="THe probability for applying RIR convolution.",
)
parser.add_argument(
"--noise_scp",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_apply_prob",
type=float,
default=1.0,
help="The probability applying Noise adding.",
)
parser.add_argument(
"--noise_db_range",
type=str,
default="13_15",
help="The range of noise decibel level.",
)
parser.add_argument(
"--pred_masked_weight",
type=float,
default=1.0,
help="weight for predictive loss for masked frames",
)
parser.add_argument(
"--pred_nomask_weight",
type=float,
default=0.0,
help="weight for predictive loss for unmasked frames",
)
parser.add_argument(
"--loss_weights",
type=float,
default=0.0,
help="weights for additional loss terms (not first one)",
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
assert check_argument_types()
return CommonCollateFn(clipping=True)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
bpemodel=args.bpemodel,
non_linguistic_symbols=args.non_linguistic_symbols,
text_cleaner=args.cleaner,
g2p_type=args.g2p,
# NOTE(kamo): Check attribute existence for backward compatibility
rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
rir_apply_prob=args.rir_apply_prob
if hasattr(args, "rir_apply_prob")
else 1.0,
noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
noise_apply_prob=args.noise_apply_prob
if hasattr(args, "noise_apply_prob")
else 1.0,
noise_db_range=args.noise_db_range
if hasattr(args, "noise_db_range")
else "13_15",
speech_volume_normalize=args.speech_volume_normalize
if hasattr(args, "rir_scp")
else None,
)
else:
retval = None
assert check_return_type(retval)
return retval
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
# for pre-training
retval = ("speech",)
return retval
@classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
assert check_argument_types()
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# 2. Data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
# 3. Normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
normalize = normalize_class(**args.normalize_conf)
else:
normalize = None
# 4. Pre-encoder input block
# NOTE(kan-bayashi): Use getattr to keep the compatibility
if getattr(args, "preencoder", None) is not None:
preencoder_class = preencoder_choices.get_class(args.preencoder)
preencoder = preencoder_class(**args.preencoder_conf)
input_size = preencoder.output_size()
else:
preencoder = None
# 5. Encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(
input_size=input_size,
**args.encoder_conf,
)
# 6. Build model
try:
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("data2vec")
model = model_class(
frontend=frontend,
specaug=specaug,
normalize=normalize,
preencoder=preencoder,
encoder=encoder,
)
# 7. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model

918
funasr_local/tasks/diar.py Normal file
View File

@@ -0,0 +1,918 @@
"""
Author: Speech Lab, Alibaba Group, China
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
https://arxiv.org/abs/2211.10243
TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
https://arxiv.org/abs/2303.05397
"""
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.datasets.collate_fn import CommonCollateFn
from funasr_local.datasets.preprocessor import CommonPreprocessor
from funasr_local.layers.abs_normalize import AbsNormalize
from funasr_local.layers.global_mvn import GlobalMVN
from funasr_local.layers.label_aggregation import LabelAggregate
from funasr_local.layers.utterance_mvn import UtteranceMVN
from funasr_local.models.e2e_diar_sond import DiarSondModel
from funasr_local.models.e2e_diar_eend_ola import DiarEENDOLAModel
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.models.encoder.conformer_encoder import ConformerEncoder
from funasr_local.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr_local.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr_local.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
from funasr_local.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr_local.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr_local.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
from funasr_local.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr_local.models.encoder.rnn_encoder import RNNEncoder
from funasr_local.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr_local.models.encoder.transformer_encoder import TransformerEncoder
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.frontend.default import DefaultFrontend
from funasr_local.models.frontend.fused import FusedFrontends
from funasr_local.models.frontend.s3prl import S3prlFrontend
from funasr_local.models.frontend.wav_frontend import WavFrontend
from funasr_local.models.frontend.wav_frontend import WavFrontendMel23
from funasr_local.models.frontend.windowing import SlidingWindow
from funasr_local.models.specaug.abs_specaug import AbsSpecAug
from funasr_local.models.specaug.specaug import SpecAug
from funasr_local.models.specaug.specaug import SpecAugLFR
from funasr_local.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr_local.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr_local.tasks.abs_task import AbsTask
from funasr_local.torch_utils.initialize import initialize
from funasr_local.train.abs_espnet_model import AbsESPnetModel
from funasr_local.train.class_choices import ClassChoices
from funasr_local.train.trainer import Trainer
from funasr_local.utils.types import float_or_none
from funasr_local.utils.types import int_or_none
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str_or_none
frontend_choices = ClassChoices(
name="frontend",
classes=dict(
default=DefaultFrontend,
sliding_window=SlidingWindow,
s3prl=S3prlFrontend,
fused=FusedFrontends,
wav_frontend=WavFrontend,
wav_frontend_mel23=WavFrontendMel23,
),
type_check=AbsFrontend,
default="default",
)
specaug_choices = ClassChoices(
name="specaug",
classes=dict(
specaug=SpecAug,
specaug_lfr=SpecAugLFR,
),
type_check=AbsSpecAug,
default=None,
optional=True,
)
normalize_choices = ClassChoices(
"normalize",
classes=dict(
global_mvn=GlobalMVN,
utterance_mvn=UtteranceMVN,
),
type_check=AbsNormalize,
default=None,
optional=True,
)
label_aggregator_choices = ClassChoices(
"label_aggregator",
classes=dict(
label_aggregator=LabelAggregate
),
type_check=torch.nn.Module,
default=None,
optional=True,
)
model_choices = ClassChoices(
"model",
classes=dict(
sond=DiarSondModel,
eend_ola=DiarEENDOLAModel,
),
type_check=AbsESPnetModel,
default="sond",
)
encoder_choices = ClassChoices(
"encoder",
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
rnn=RNNEncoder,
sanm=SANMEncoder,
san=SelfAttentionEncoder,
fsmn=FsmnEncoder,
conv=ConvEncoder,
resnet34=ResNet34Diar,
resnet34_sp_l2reg=ResNet34SpL2RegDiar,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
ecapa_tdnn=ECAPA_TDNN,
eend_ola_transformer=EENDOLATransformerEncoder,
),
type_check=torch.nn.Module,
default="resnet34",
)
speaker_encoder_choices = ClassChoices(
"speaker_encoder",
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
rnn=RNNEncoder,
sanm=SANMEncoder,
san=SelfAttentionEncoder,
fsmn=FsmnEncoder,
conv=ConvEncoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
),
type_check=AbsEncoder,
default=None,
optional=True
)
cd_scorer_choices = ClassChoices(
"cd_scorer",
classes=dict(
san=SelfAttentionEncoder,
),
type_check=AbsEncoder,
default=None,
optional=True,
)
ci_scorer_choices = ClassChoices(
"ci_scorer",
classes=dict(
dot=DotScorer,
cosine=CosScorer,
conv=ConvEncoder,
),
type_check=torch.nn.Module,
default=None,
optional=True,
)
# decoder is used for output (e.g. post_net in SOND)
decoder_choices = ClassChoices(
"decoder",
classes=dict(
rnn=RNNEncoder,
fsmn=FsmnEncoder,
),
type_check=torch.nn.Module,
default="fsmn",
)
# encoder_decoder_attractor is used for EEND-OLA
encoder_decoder_attractor_choices = ClassChoices(
"encoder_decoder_attractor",
classes=dict(
eda=EncoderDecoderAttractor,
),
type_check=torch.nn.Module,
default="eda",
)
class DiarTask(AbsTask):
# If you need more than 1 optimizer, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --label_aggregator and --label_aggregator_conf
label_aggregator_choices,
# --model and --model_conf
model_choices,
# --encoder and --encoder_conf
encoder_choices,
# --speaker_encoder and --speaker_encoder_conf
speaker_encoder_choices,
# --cd_scorer and cd_scorer_conf
cd_scorer_choices,
# --ci_scorer and ci_scorer_conf
ci_scorer_choices,
# --decoder and --decoder_conf
decoder_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(description="Task related")
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
# required = parser.get_default("required")
# required += ["token_list"]
group.add_argument(
"--token_list",
type=str_or_none,
default=None,
help="A text mapping int-id to token",
)
group.add_argument(
"--split_with_space",
type=str2bool,
default=True,
help="whether to split text using <space>",
)
group.add_argument(
"--seg_dict_file",
type=str,
default=None,
help="seg_dict_file for text processing",
)
group.add_argument(
"--init",
type=lambda x: str_or_none(x.lower()),
default=None,
help="The initialization method",
choices=[
"chainer",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
None,
],
)
group.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of input dimension of the feature",
)
group = parser.add_argument_group(description="Preprocess related")
group.add_argument(
"--use_preprocessor",
type=str2bool,
default=True,
help="Apply preprocessing to data or not",
)
group.add_argument(
"--token_type",
type=str,
default="char",
choices=["char"],
help="The text will be tokenized in the specified level token",
)
parser.add_argument(
"--speech_volume_normalize",
type=float_or_none,
default=None,
help="Scale the maximum amplitude to the given value.",
)
parser.add_argument(
"--rir_scp",
type=str_or_none,
default=None,
help="The file path of rir scp file.",
)
parser.add_argument(
"--rir_apply_prob",
type=float,
default=1.0,
help="THe probability for applying RIR convolution.",
)
parser.add_argument(
"--cmvn_file",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_scp",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_apply_prob",
type=float,
default=1.0,
help="The probability applying Noise adding.",
)
parser.add_argument(
"--noise_db_range",
type=str,
default="13_15",
help="The range of noise decibel level.",
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
token_type=args.token_type,
token_list=args.token_list,
bpemodel=None,
non_linguistic_symbols=None,
text_cleaner=None,
g2p_type=None,
split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
# NOTE(kamo): Check attribute existence for backward compatibility
rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
rir_apply_prob=args.rir_apply_prob
if hasattr(args, "rir_apply_prob")
else 1.0,
noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
noise_apply_prob=args.noise_apply_prob
if hasattr(args, "noise_apply_prob")
else 1.0,
noise_db_range=args.noise_db_range
if hasattr(args, "noise_db_range")
else "13_15",
speech_volume_normalize=args.speech_volume_normalize
if hasattr(args, "rir_scp")
else None,
)
else:
retval = None
assert check_return_type(retval)
return retval
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
if not inference:
retval = ("speech", "profile", "binary_labels")
else:
# Recognition mode
retval = ("speech", "profile")
return retval
@classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
# Overwriting token_list to keep it as "portable".
args.token_list = list(token_list)
elif isinstance(args.token_list, (tuple, list)):
token_list = list(args.token_list)
else:
raise RuntimeError("token_list must be str or list")
vocab_size = len(token_list)
logging.info(f"Vocabulary size: {vocab_size}")
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
if args.frontend == 'wav_frontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# 2. Data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
# 3. Normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
normalize = normalize_class(**args.normalize_conf)
else:
normalize = None
# 4. Encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
# 5. speaker encoder
if getattr(args, "speaker_encoder", None) is not None:
speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
else:
speaker_encoder = None
# 6. CI & CD scorer
if getattr(args, "ci_scorer", None) is not None:
ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
else:
ci_scorer = None
if getattr(args, "cd_scorer", None) is not None:
cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
else:
cd_scorer = None
# 7. Decoder
decoder_class = decoder_choices.get_class(args.decoder)
decoder = decoder_class(**args.decoder_conf)
if getattr(args, "label_aggregator", None) is not None:
label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
else:
label_aggregator = None
# 9. Build model
model_class = model_choices.get_class(args.model)
model = model_class(
vocab_size=vocab_size,
frontend=frontend,
specaug=specaug,
normalize=normalize,
label_aggregator=label_aggregator,
encoder=encoder,
speaker_encoder=speaker_encoder,
ci_scorer=ci_scorer,
cd_scorer=cd_scorer,
decoder=decoder,
token_list=token_list,
**args.model_conf,
)
# 10. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@classmethod
def build_model_from_file(
cls,
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
device: Union[str, torch.device] = "cpu",
):
"""Build model from the files.
This method is used for inference or fine-tuning.
Args:
config_file: The yaml file saved when training.
model_file: The model file saved when training.
cmvn_file: The cmvn file for front-end
device: Device type, "cpu", "cuda", or "cuda:N".
"""
assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
"if the argument 'config_file' is not specified."
)
config_file = Path(model_file).parent / "config.yaml"
else:
config_file = Path(config_file)
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
if cmvn_file is not None:
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
if not isinstance(model, AbsESPnetModel):
raise RuntimeError(
f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
model_name_pth = None
if model_file is not None:
logging.info("model_file is {}".format(model_file))
if device == "cuda":
device = f"cuda:{torch.cuda.current_device()}"
model_dir = os.path.dirname(model_file)
model_name = os.path.basename(model_file)
if "model.ckpt-" in model_name or ".bin" in model_name:
if ".bin" in model_name:
model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
else:
model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
if os.path.exists(model_name_pth):
logging.info("model_file is load from pth: {}".format(model_name_pth))
model_dict = torch.load(model_name_pth, map_location=device)
else:
model_dict = cls.convert_tf2torch(model, model_file)
model.load_state_dict(model_dict)
else:
model_dict = torch.load(model_file, map_location=device)
model_dict = cls.fileter_model_dict(model_dict, model.state_dict())
model.load_state_dict(model_dict)
if model_name_pth is not None and not os.path.exists(model_name_pth):
torch.save(model_dict, model_name_pth)
logging.info("model_file is saved to pth: {}".format(model_name_pth))
return model, args
@classmethod
def fileter_model_dict(cls, src_dict: dict, dest_dict: dict):
from collections import OrderedDict
new_dict = OrderedDict()
for key, value in src_dict.items():
if key in dest_dict:
new_dict[key] = value
else:
logging.info("{} is no longer needed in this model.".format(key))
for key, value in dest_dict.items():
if key not in new_dict:
logging.warning("{} is missed in checkpoint.".format(key))
return new_dict
@classmethod
def convert_tf2torch(
cls,
model,
ckpt,
):
logging.info("start convert tf model to torch model")
from funasr_local.modules.streaming_utils.load_fr_tf import load_tf_dict
var_dict_tf = load_tf_dict(ckpt)
var_dict_torch = model.state_dict()
var_dict_torch_update = dict()
# speech encoder
if model.encoder is not None:
var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
# speaker encoder
if model.speaker_encoder is not None:
var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
# cd scorer
if model.cd_scorer is not None:
var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
# ci scorer
if model.ci_scorer is not None:
var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
# decoder
if model.decoder is not None:
var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update
class EENDOLADiarTask(AbsTask):
# If you need more than 1 optimizer, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --specaug and --specaug_conf
model_choices,
# --encoder and --encoder_conf
encoder_choices,
# --speaker_encoder and --speaker_encoder_conf
encoder_decoder_attractor_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(description="Task related")
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
# required = parser.get_default("required")
# required += ["token_list"]
group.add_argument(
"--token_list",
type=str_or_none,
default=None,
help="A text mapping int-id to token",
)
group.add_argument(
"--split_with_space",
type=str2bool,
default=True,
help="whether to split text using <space>",
)
group.add_argument(
"--seg_dict_file",
type=str,
default=None,
help="seg_dict_file for text processing",
)
group.add_argument(
"--init",
type=lambda x: str_or_none(x.lower()),
default=None,
help="The initialization method",
choices=[
"chainer",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
None,
],
)
group.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of input dimension of the feature",
)
group = parser.add_argument_group(description="Preprocess related")
group.add_argument(
"--use_preprocessor",
type=str2bool,
default=True,
help="Apply preprocessing to data or not",
)
group.add_argument(
"--token_type",
type=str,
default="char",
choices=["char"],
help="The text will be tokenized in the specified level token",
)
parser.add_argument(
"--speech_volume_normalize",
type=float_or_none,
default=None,
help="Scale the maximum amplitude to the given value.",
)
parser.add_argument(
"--rir_scp",
type=str_or_none,
default=None,
help="The file path of rir scp file.",
)
parser.add_argument(
"--rir_apply_prob",
type=float,
default=1.0,
help="THe probability for applying RIR convolution.",
)
parser.add_argument(
"--cmvn_file",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_scp",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_apply_prob",
type=float,
default=1.0,
help="The probability applying Noise adding.",
)
parser.add_argument(
"--noise_db_range",
type=str,
default="13_15",
help="The range of noise decibel level.",
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
assert check_argument_types()
# if args.use_preprocessor:
# retval = CommonPreprocessor(
# train=train,
# token_type=args.token_type,
# token_list=args.token_list,
# bpemodel=None,
# non_linguistic_symbols=None,
# text_cleaner=None,
# g2p_type=None,
# split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
# seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
# # NOTE(kamo): Check attribute existence for backward compatibility
# rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
# rir_apply_prob=args.rir_apply_prob
# if hasattr(args, "rir_apply_prob")
# else 1.0,
# noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
# noise_apply_prob=args.noise_apply_prob
# if hasattr(args, "noise_apply_prob")
# else 1.0,
# noise_db_range=args.noise_db_range
# if hasattr(args, "noise_db_range")
# else "13_15",
# speech_volume_normalize=args.speech_volume_normalize
# if hasattr(args, "rir_scp")
# else None,
# )
# else:
# retval = None
# assert check_return_type(retval)
return None
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
if not inference:
retval = ("speech", )
else:
# Recognition mode
retval = ("speech", )
return retval
@classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
assert check_argument_types()
# 1. frontend
if args.input_size is None or args.frontend == "wav_frontend_mel23":
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
if args.frontend == 'wav_frontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# 2. Encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(**args.encoder_conf)
# 3. EncoderDecoderAttractor
encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
# 9. Build model
model_class = model_choices.get_class(args.model)
model = model_class(
frontend=frontend,
encoder=encoder,
encoder_decoder_attractor=encoder_decoder_attractor,
**args.model_conf,
)
# 10. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@classmethod
def build_model_from_file(
cls,
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
device: str = "cpu",
):
"""Build model from the files.
This method is used for inference or fine-tuning.
Args:
config_file: The yaml file saved when training.
model_file: The model file saved when training.
cmvn_file: The cmvn file for front-end
device: Device type, "cpu", "cuda", or "cuda:N".
"""
assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
"if the argument 'config_file' is not specified."
)
config_file = Path(model_file).parent / "config.yaml"
else:
config_file = Path(config_file)
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
args = argparse.Namespace(**args)
model = cls.build_model(args)
if not isinstance(model, AbsESPnetModel):
raise RuntimeError(
f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
)
if model_file is not None:
if device == "cuda":
device = f"cuda:{torch.cuda.current_device()}"
checkpoint = torch.load(model_file, map_location=device)
if "state_dict" in checkpoint.keys():
model.load_state_dict(checkpoint["state_dict"])
else:
model.load_state_dict(checkpoint)
model.to(device)
return model, args

211
funasr_local/tasks/lm.py Normal file
View File

@@ -0,0 +1,211 @@
import argparse
import logging
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.datasets.collate_fn import CommonCollateFn
from funasr_local.datasets.preprocessor import CommonPreprocessor
from funasr_local.lm.abs_model import AbsLM
from funasr_local.lm.abs_model import LanguageModel
from funasr_local.lm.seq_rnn_lm import SequentialRNNLM
from funasr_local.lm.transformer_lm import TransformerLM
from funasr_local.tasks.abs_task import AbsTask
from funasr_local.text.phoneme_tokenizer import g2p_choices
from funasr_local.torch_utils.initialize import initialize
from funasr_local.train.class_choices import ClassChoices
from funasr_local.train.trainer import Trainer
from funasr_local.utils.get_default_kwargs import get_default_kwargs
from funasr_local.utils.nested_dict_action import NestedDictAction
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str_or_none
lm_choices = ClassChoices(
"lm",
classes=dict(
seq_rnn=SequentialRNNLM,
transformer=TransformerLM,
),
type_check=AbsLM,
default="seq_rnn",
)
class LMTask(AbsTask):
# If you need more than one optimizers, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [lm_choices]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
# NOTE(kamo): Use '_' instead of '-' to avoid confusion
assert check_argument_types()
group = parser.add_argument_group(description="Task related")
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
required = parser.get_default("required")
# required += ["token_list"]
group.add_argument(
"--token_list",
type=str_or_none,
default=None,
help="A text mapping int-id to token",
)
group.add_argument(
"--init",
type=lambda x: str_or_none(x.lower()),
default=None,
help="The initialization method",
choices=[
"chainer",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
None,
],
)
group.add_argument(
"--model_conf",
action=NestedDictAction,
default=get_default_kwargs(LanguageModel),
help="The keyword arguments for model class.",
)
group = parser.add_argument_group(description="Preprocess related")
group.add_argument(
"--use_preprocessor",
type=str2bool,
default=True,
help="Apply preprocessing to data or not",
)
group.add_argument(
"--token_type",
type=str,
default="bpe",
choices=["bpe", "char", "word"],
help="",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model file fo sentencepiece",
)
parser.add_argument(
"--non_linguistic_symbols",
type=str_or_none,
help="non_linguistic_symbols file path",
)
parser.add_argument(
"--cleaner",
type=str_or_none,
choices=[None, "tacotron", "jaconv", "vietnamese"],
default=None,
help="Apply text cleaning",
)
parser.add_argument(
"--g2p",
type=str_or_none,
choices=g2p_choices,
default=None,
help="Specify g2p method if --token_type=phn",
)
for class_choices in cls.class_choices_list:
class_choices.add_arguments(group)
assert check_return_type(parser)
return parser
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
assert check_argument_types()
return CommonCollateFn(int_pad_value=0)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
token_type=args.token_type,
token_list=args.token_list,
bpemodel=args.bpemodel,
text_cleaner=args.cleaner,
g2p_type=args.g2p,
non_linguistic_symbols=args.non_linguistic_symbols,
)
else:
retval = None
assert check_return_type(retval)
return retval
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ("text",)
return retval
@classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
return retval
@classmethod
def build_model(cls, args: argparse.Namespace) -> LanguageModel:
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
# "args" is saved as it is in a yaml file by BaseTask.main().
# Overwriting token_list to keep it as "portable".
args.token_list = token_list.copy()
elif isinstance(args.token_list, (tuple, list)):
token_list = args.token_list.copy()
else:
raise RuntimeError("token_list must be str or dict")
vocab_size = len(token_list)
logging.info(f"Vocabulary size: {vocab_size}")
# 1. Build LM model
lm_class = lm_choices.get_class(args.lm)
lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
# 2. Build ESPnetModel
# Assume the last-id is sos_and_eos
model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
# 3. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model

View File

@@ -0,0 +1,229 @@
import argparse
import logging
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.datasets.collate_fn import CommonCollateFn
from funasr_local.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
from funasr_local.train.abs_model import AbsPunctuation
from funasr_local.train.abs_model import PunctuationModel
from funasr_local.models.target_delay_transformer import TargetDelayTransformer
from funasr_local.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr_local.tasks.abs_task import AbsTask
from funasr_local.text.phoneme_tokenizer import g2p_choices
from funasr_local.torch_utils.initialize import initialize
from funasr_local.train.class_choices import ClassChoices
from funasr_local.train.trainer import Trainer
from funasr_local.utils.get_default_kwargs import get_default_kwargs
from funasr_local.utils.nested_dict_action import NestedDictAction
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str_or_none
punc_choices = ClassChoices(
"punctuation",
classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer),
type_check=AbsPunctuation,
default="target_delay",
)
class PunctuationTask(AbsTask):
# If you need more than one optimizers, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [punc_choices]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
# NOTE(kamo): Use '_' instead of '-' to avoid confusion
assert check_argument_types()
group = parser.add_argument_group(description="Task related")
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
required = parser.get_default("required")
group.add_argument(
"--token_list",
type=str_or_none,
default=None,
help="A text mapping int-id to token",
)
group.add_argument(
"--init",
type=lambda x: str_or_none(x.lower()),
default=None,
help="The initialization method",
choices=[
"chainer",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
None,
],
)
group.add_argument(
"--model_conf",
action=NestedDictAction,
default=get_default_kwargs(PunctuationModel),
help="The keyword arguments for model class.",
)
group = parser.add_argument_group(description="Preprocess related")
group.add_argument(
"--use_preprocessor",
type=str2bool,
default=True,
help="Apply preprocessing to data or not",
)
group.add_argument(
"--token_type",
type=str,
default="bpe",
choices=["bpe", "char", "word"],
help="",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model file fo sentencepiece",
)
parser.add_argument(
"--non_linguistic_symbols",
type=str_or_none,
help="non_linguistic_symbols file path",
)
parser.add_argument(
"--cleaner",
type=str_or_none,
choices=[None, "tacotron", "jaconv", "vietnamese"],
default=None,
help="Apply text cleaning",
)
parser.add_argument(
"--g2p",
type=str_or_none,
choices=g2p_choices,
default=None,
help="Specify g2p method if --token_type=phn",
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)
assert check_return_type(parser)
return parser
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
assert check_argument_types()
return CommonCollateFn(int_pad_value=0)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
assert check_argument_types()
token_types = [args.token_type, args.token_type]
token_lists = [args.token_list, args.punc_list]
bpemodels = [args.bpemodel, args.bpemodel]
text_names = ["text", "punc"]
if args.use_preprocessor:
retval = PuncTrainTokenizerCommonPreprocessor(
train=train,
token_type=token_types,
token_list=token_lists,
bpemodel=bpemodels,
text_cleaner=args.cleaner,
g2p_type=args.g2p,
text_name = text_names,
non_linguistic_symbols=args.non_linguistic_symbols,
)
else:
retval = None
assert check_return_type(retval)
return retval
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ("text", "punc")
if inference:
retval = ("text", )
return retval
@classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ("vad",)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace) -> PunctuationModel:
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
# "args" is saved as it is in a yaml file by BaseTask.main().
# Overwriting token_list to keep it as "portable".
args.token_list = token_list.copy()
if isinstance(args.punc_list, str):
with open(args.punc_list, encoding="utf-8") as f2:
pairs = [line.rstrip().split(":") for line in f2]
punc_list = [pair[0] for pair in pairs]
punc_weight_list = [float(pair[1]) for pair in pairs]
args.punc_list = punc_list.copy()
elif isinstance(args.punc_list, list):
punc_list = args.punc_list.copy()
punc_weight_list = [1] * len(punc_list)
if isinstance(args.token_list, (tuple, list)):
token_list = args.token_list.copy()
else:
raise RuntimeError("token_list must be str or dict")
vocab_size = len(token_list)
punc_size = len(punc_list)
logging.info(f"Vocabulary size: {vocab_size}")
# 1. Build PUNC model
punc_class = punc_choices.get_class(args.punctuation)
punc = punc_class(vocab_size=vocab_size, punc_size=punc_size, **args.punctuation_conf)
# 2. Build ESPnetModel
# Assume the last-id is sos_and_eos
if "punc_weight" in args.model_conf:
args.model_conf.pop("punc_weight")
model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
# FIXME(kamo): Should be done in model?
# 3. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model

545
funasr_local/tasks/sv.py Normal file
View File

@@ -0,0 +1,545 @@
"""
Author: Speech Lab, Alibaba Group, China
"""
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.datasets.collate_fn import CommonCollateFn
from funasr_local.datasets.preprocessor import CommonPreprocessor
from funasr_local.layers.abs_normalize import AbsNormalize
from funasr_local.layers.global_mvn import GlobalMVN
from funasr_local.layers.utterance_mvn import UtteranceMVN
from funasr_local.models.e2e_asr import ESPnetASRModel
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.models.encoder.rnn_encoder import RNNEncoder
from funasr_local.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
from funasr_local.models.pooling.statistic_pooling import StatisticPooling
from funasr_local.models.decoder.sv_decoder import DenseDecoder
from funasr_local.models.e2e_sv import ESPnetSVModel
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.frontend.default import DefaultFrontend
from funasr_local.models.frontend.fused import FusedFrontends
from funasr_local.models.frontend.s3prl import S3prlFrontend
from funasr_local.models.frontend.windowing import SlidingWindow
from funasr_local.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr_local.models.postencoder.hugging_face_transformers_postencoder import (
HuggingFaceTransformersPostEncoder, # noqa: H301
)
from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr_local.models.preencoder.linear import LinearProjection
from funasr_local.models.preencoder.sinc import LightweightSincConvs
from funasr_local.models.specaug.abs_specaug import AbsSpecAug
from funasr_local.models.specaug.specaug import SpecAug
from funasr_local.tasks.abs_task import AbsTask
from funasr_local.torch_utils.initialize import initialize
from funasr_local.train.abs_espnet_model import AbsESPnetModel
from funasr_local.train.class_choices import ClassChoices
from funasr_local.train.trainer import Trainer
from funasr_local.utils.types import float_or_none
from funasr_local.utils.types import int_or_none
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str_or_none
from funasr_local.models.frontend.wav_frontend import WavFrontend
frontend_choices = ClassChoices(
name="frontend",
classes=dict(
default=DefaultFrontend,
sliding_window=SlidingWindow,
s3prl=S3prlFrontend,
fused=FusedFrontends,
wav_frontend=WavFrontend,
),
type_check=AbsFrontend,
default="default",
)
specaug_choices = ClassChoices(
name="specaug",
classes=dict(
specaug=SpecAug,
),
type_check=AbsSpecAug,
default=None,
optional=True,
)
normalize_choices = ClassChoices(
"normalize",
classes=dict(
global_mvn=GlobalMVN,
utterance_mvn=UtteranceMVN,
),
type_check=AbsNormalize,
default=None,
optional=True,
)
model_choices = ClassChoices(
"model",
classes=dict(
espnet=ESPnetSVModel,
),
type_check=AbsESPnetModel,
default="espnet",
)
preencoder_choices = ClassChoices(
name="preencoder",
classes=dict(
sinc=LightweightSincConvs,
linear=LinearProjection,
),
type_check=AbsPreEncoder,
default=None,
optional=True,
)
encoder_choices = ClassChoices(
"encoder",
classes=dict(
resnet34=ResNet34,
resnet34_sp_l2reg=ResNet34_SP_L2Reg,
rnn=RNNEncoder,
),
type_check=AbsEncoder,
default="resnet34",
)
postencoder_choices = ClassChoices(
name="postencoder",
classes=dict(
hugging_face_transformers=HuggingFaceTransformersPostEncoder,
),
type_check=AbsPostEncoder,
default=None,
optional=True,
)
pooling_choices = ClassChoices(
name="pooling_type",
classes=dict(
statistic=StatisticPooling,
),
type_check=torch.nn.Module,
default="statistic",
)
decoder_choices = ClassChoices(
"decoder",
classes=dict(
dense=DenseDecoder,
),
type_check=AbsDecoder,
default="dense",
)
class SVTask(AbsTask):
# If you need more than one optimizers, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --model and --model_conf
model_choices,
# --preencoder and --preencoder_conf
preencoder_choices,
# --encoder and --encoder_conf
encoder_choices,
# --postencoder and --postencoder_conf
postencoder_choices,
# --pooling and --pooling_conf
pooling_choices,
# --decoder and --decoder_conf
decoder_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(description="Task related")
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
required = parser.get_default("required")
required += ["token_list"]
group.add_argument(
"--token_list",
type=str_or_none,
default=None,
help="A text mapping int-id to speaker name",
)
group.add_argument(
"--init",
type=lambda x: str_or_none(x.lower()),
default=None,
help="The initialization method",
choices=[
"chainer",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
None,
],
)
group.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of input dimension of the feature",
)
group = parser.add_argument_group(description="Preprocess related")
group.add_argument(
"--use_preprocessor",
type=str2bool,
default=True,
help="Apply preprocessing to data or not",
)
parser.add_argument(
"--cleaner",
type=str_or_none,
choices=[None, "tacotron", "jaconv", "vietnamese"],
default=None,
help="Apply text cleaning",
)
parser.add_argument(
"--speech_volume_normalize",
type=float_or_none,
default=None,
help="Scale the maximum amplitude to the given value.",
)
parser.add_argument(
"--rir_scp",
type=str_or_none,
default=None,
help="The file path of rir scp file.",
)
parser.add_argument(
"--rir_apply_prob",
type=float,
default=1.0,
help="THe probability for applying RIR convolution.",
)
parser.add_argument(
"--noise_scp",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_apply_prob",
type=float,
default=1.0,
help="The probability applying Noise adding.",
)
parser.add_argument(
"--noise_db_range",
type=str,
default="13_15",
help="The range of noise decibel level.",
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
token_type=None,
token_list=None,
bpemodel=None,
non_linguistic_symbols=None,
text_cleaner=args.cleaner,
g2p_type=None,
# NOTE(kamo): Check attribute existence for backward compatibility
rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
rir_apply_prob=args.rir_apply_prob
if hasattr(args, "rir_apply_prob")
else 1.0,
noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
noise_apply_prob=args.noise_apply_prob
if hasattr(args, "noise_apply_prob")
else 1.0,
noise_db_range=args.noise_db_range
if hasattr(args, "noise_db_range")
else "13_15",
speech_volume_normalize=args.speech_volume_normalize
if hasattr(args, "rir_scp")
else None,
)
else:
retval = None
assert check_return_type(retval)
return retval
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
if not inference:
retval = ("speech", "text")
else:
# Recognition mode
retval = ("speech",)
return retval
@classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
if inference:
retval = ("ref_speech",)
assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace) -> ESPnetSVModel:
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
# Overwriting token_list to keep it as "portable".
args.token_list = list(token_list)
elif isinstance(args.token_list, (tuple, list)):
token_list = list(args.token_list)
else:
raise RuntimeError("token_list must be str or list")
vocab_size = len(token_list)
logging.info(f"Speaker number: {vocab_size}")
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# 2. Data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
# 3. Normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
normalize = normalize_class(**args.normalize_conf)
else:
normalize = None
# 4. Pre-encoder input block
# NOTE(kan-bayashi): Use getattr to keep the compatibility
if getattr(args, "preencoder", None) is not None:
preencoder_class = preencoder_choices.get_class(args.preencoder)
preencoder = preencoder_class(**args.preencoder_conf)
input_size = preencoder.output_size()
else:
preencoder = None
# 5. Encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
# 6. Post-encoder block
# NOTE(kan-bayashi): Use getattr to keep the compatibility
encoder_output_size = encoder.output_size()
if getattr(args, "postencoder", None) is not None:
postencoder_class = postencoder_choices.get_class(args.postencoder)
postencoder = postencoder_class(
input_size=encoder_output_size, **args.postencoder_conf
)
encoder_output_size = postencoder.output_size()
else:
postencoder = None
# 7. Pooling layer
pooling_class = pooling_choices.get_class(args.pooling_type)
pooling_dim = (2, 3)
eps = 1e-12
if hasattr(args, "pooling_type_conf"):
if "pooling_dim" in args.pooling_type_conf:
pooling_dim = args.pooling_type_conf["pooling_dim"]
if "eps" in args.pooling_type_conf:
eps = args.pooling_type_conf["eps"]
pooling_layer = pooling_class(
pooling_dim=pooling_dim,
eps=eps,
)
if args.pooling_type == "statistic":
encoder_output_size *= 2
# 8. Decoder
decoder_class = decoder_choices.get_class(args.decoder)
decoder = decoder_class(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
**args.decoder_conf,
)
# 7. Build model
try:
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("espnet")
model = model_class(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
specaug=specaug,
normalize=normalize,
preencoder=preencoder,
encoder=encoder,
postencoder=postencoder,
pooling_layer=pooling_layer,
decoder=decoder,
**args.model_conf,
)
# FIXME(kamo): Should be done in model?
# 8. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@classmethod
def build_model_from_file(
cls,
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
device: str = "cpu",
):
"""Build model from the files.
This method is used for inference or fine-tuning.
Args:
config_file: The yaml file saved when training.
model_file: The model file saved when training.
cmvn_file: The cmvn file for front-end
device: Device type, "cpu", "cuda", or "cuda:N".
"""
assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
"if the argument 'config_file' is not specified."
)
config_file = Path(model_file).parent / "config.yaml"
else:
config_file = Path(config_file)
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
if cmvn_file is not None:
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
if not isinstance(model, AbsESPnetModel):
raise RuntimeError(
f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
model_name_pth = None
if model_file is not None:
logging.info("model_file is {}".format(model_file))
if device == "cuda":
device = f"cuda:{torch.cuda.current_device()}"
model_dir = os.path.dirname(model_file)
model_name = os.path.basename(model_file)
if "model.ckpt-" in model_name or ".bin" in model_name:
if ".bin" in model_name:
model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
else:
model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
if os.path.exists(model_name_pth):
logging.info("model_file is load from pth: {}".format(model_name_pth))
model_dict = torch.load(model_name_pth, map_location=device)
else:
model_dict = cls.convert_tf2torch(model, model_file)
model.load_state_dict(model_dict)
else:
model_dict = torch.load(model_file, map_location=device)
model.load_state_dict(model_dict)
if model_name_pth is not None and not os.path.exists(model_name_pth):
torch.save(model_dict, model_name_pth)
logging.info("model_file is saved to pth: {}".format(model_name_pth))
return model, args
@classmethod
def convert_tf2torch(
cls,
model,
ckpt,
):
logging.info("start convert tf model to torch model")
from funasr_local.modules.streaming_utils.load_fr_tf import load_tf_dict
var_dict_tf = load_tf_dict(ckpt)
var_dict_torch = model.state_dict()
var_dict_torch_update = dict()
# speech encoder
var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
# pooling layer
var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
# decoder
var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update

363
funasr_local/tasks/vad.py Normal file
View File

@@ -0,0 +1,363 @@
import argparse
import logging
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import os
from pathlib import Path
from typing import Tuple
from typing import Union
import yaml
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr_local.datasets.collate_fn import CommonCollateFn
from funasr_local.datasets.preprocessor import CommonPreprocessor
from funasr_local.models.ctc import CTC
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.models.decoder.rnn_decoder import RNNDecoder
from funasr_local.models.decoder.transformer_decoder import (
DynamicConvolution2DTransformerDecoder, # noqa: H301
)
from funasr_local.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr_local.models.decoder.transformer_decoder import (
LightweightConvolution2DTransformerDecoder, # noqa: H301
)
from funasr_local.models.decoder.transformer_decoder import (
LightweightConvolutionTransformerDecoder, # noqa: H301
)
from funasr_local.models.decoder.transformer_decoder import TransformerDecoder
from funasr_local.models.encoder.abs_encoder import AbsEncoder
from funasr_local.models.encoder.conformer_encoder import ConformerEncoder
from funasr_local.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr_local.models.encoder.rnn_encoder import RNNEncoder
from funasr_local.models.encoder.transformer_encoder import TransformerEncoder
from funasr_local.models.frontend.abs_frontend import AbsFrontend
from funasr_local.models.frontend.default import DefaultFrontend
from funasr_local.models.frontend.fused import FusedFrontends
from funasr_local.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr_local.models.frontend.s3prl import S3prlFrontend
from funasr_local.models.frontend.windowing import SlidingWindow
from funasr_local.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr_local.models.postencoder.hugging_face_transformers_postencoder import (
HuggingFaceTransformersPostEncoder, # noqa: H301
)
from funasr_local.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr_local.models.preencoder.linear import LinearProjection
from funasr_local.models.preencoder.sinc import LightweightSincConvs
from funasr_local.models.specaug.abs_specaug import AbsSpecAug
from funasr_local.models.specaug.specaug import SpecAug
from funasr_local.layers.abs_normalize import AbsNormalize
from funasr_local.layers.global_mvn import GlobalMVN
from funasr_local.layers.utterance_mvn import UtteranceMVN
from funasr_local.tasks.abs_task import AbsTask
from funasr_local.text.phoneme_tokenizer import g2p_choices
from funasr_local.train.abs_espnet_model import AbsESPnetModel
from funasr_local.train.class_choices import ClassChoices
from funasr_local.train.trainer import Trainer
from funasr_local.utils.get_default_kwargs import get_default_kwargs
from funasr_local.utils.nested_dict_action import NestedDictAction
from funasr_local.utils.types import float_or_none
from funasr_local.utils.types import int_or_none
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str_or_none
from funasr_local.models.specaug.specaug import SpecAugLFR
from funasr_local.models.predictor.cif import CifPredictor, CifPredictorV2
from funasr_local.modules.subsampling import Conv1dSubsampling
from funasr_local.models.e2e_vad import E2EVadModel
from funasr_local.models.encoder.fsmn_encoder import FSMN
frontend_choices = ClassChoices(
name="frontend",
classes=dict(
default=DefaultFrontend,
sliding_window=SlidingWindow,
s3prl=S3prlFrontend,
fused=FusedFrontends,
wav_frontend=WavFrontend,
wav_frontend_online=WavFrontendOnline,
),
type_check=AbsFrontend,
default="default",
)
specaug_choices = ClassChoices(
name="specaug",
classes=dict(
specaug=SpecAug,
specaug_lfr=SpecAugLFR,
),
type_check=AbsSpecAug,
default=None,
optional=True,
)
normalize_choices = ClassChoices(
"normalize",
classes=dict(
global_mvn=GlobalMVN,
utterance_mvn=UtteranceMVN,
),
type_check=AbsNormalize,
default=None,
optional=True,
)
model_choices = ClassChoices(
"model",
classes=dict(
e2evad=E2EVadModel,
),
type_check=object,
default="e2evad",
)
encoder_choices = ClassChoices(
"encoder",
classes=dict(
fsmn=FSMN,
),
type_check=torch.nn.Module,
default="fsmn",
)
class VADTask(AbsTask):
# If you need more than one optimizers, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --model and --model_conf
model_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(description="Task related")
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
# required = parser.get_default("required")
# required += ["token_list"]
group.add_argument(
"--init",
type=lambda x: str_or_none(x.lower()),
default=None,
help="The initialization method",
choices=[
"chainer",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
None,
],
)
group.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of input dimension of the feature",
)
group = parser.add_argument_group(description="Preprocess related")
parser.add_argument(
"--speech_volume_normalize",
type=float_or_none,
default=None,
help="Scale the maximum amplitude to the given value.",
)
parser.add_argument(
"--rir_scp",
type=str_or_none,
default=None,
help="The file path of rir scp file.",
)
parser.add_argument(
"--rir_apply_prob",
type=float,
default=1.0,
help="THe probability for applying RIR convolution.",
)
parser.add_argument(
"--cmvn_file",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_scp",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_apply_prob",
type=float,
default=1.0,
help="The probability applying Noise adding.",
)
parser.add_argument(
"--noise_db_range",
type=str,
default="13_15",
help="The range of noise decibel level.",
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
assert check_argument_types()
# if args.use_preprocessor:
# retval = CommonPreprocessor(
# train=train,
# # NOTE(kamo): Check attribute existence for backward compatibility
# rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
# rir_apply_prob=args.rir_apply_prob
# if hasattr(args, "rir_apply_prob")
# else 1.0,
# noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
# noise_apply_prob=args.noise_apply_prob
# if hasattr(args, "noise_apply_prob")
# else 1.0,
# noise_db_range=args.noise_db_range
# if hasattr(args, "noise_db_range")
# else "13_15",
# speech_volume_normalize=args.speech_volume_normalize
# if hasattr(args, "rir_scp")
# else None,
# )
# else:
# retval = None
retval = None
assert check_return_type(retval)
return retval
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
if not inference:
retval = ("speech", "text")
else:
# Recognition mode
retval = ("speech",)
return retval
@classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
assert check_argument_types()
# 4. Encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(**args.encoder_conf)
# 5. Build model
try:
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("e2evad")
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
if args.frontend == 'wav_frontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@classmethod
def build_model_from_file(
cls,
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
device: str = "cpu",
cmvn_file: Union[Path, str] = None,
):
"""Build model from the files.
This method is used for inference or fine-tuning.
Args:
config_file: The yaml file saved when training.
model_file: The model file saved when training.
device: Device type, "cpu", "cuda", or "cuda:N".
"""
assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
"if the argument 'config_file' is not specified."
)
config_file = Path(model_file).parent / "config.yaml"
else:
config_file = Path(config_file)
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
#if cmvn_file is not None:
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
model.to(device)
model_dict = dict()
model_name_pth = None
if model_file is not None:
logging.info("model_file is {}".format(model_file))
if device == "cuda":
device = f"cuda:{torch.cuda.current_device()}"
model_dir = os.path.dirname(model_file)
model_name = os.path.basename(model_file)
model_dict = torch.load(model_file, map_location=device)
model.encoder.load_state_dict(model_dict)
return model, args