mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
156
funasr_local/modules/rnn/argument.py
Normal file
156
funasr_local/modules/rnn/argument.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# Copyright 2020 Hirofumi Inaguma
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Conformer common arguments."""
|
||||
|
||||
|
||||
def add_arguments_rnn_encoder_common(group):
|
||||
"""Define common arguments for RNN encoder."""
|
||||
group.add_argument(
|
||||
"--etype",
|
||||
default="blstmp",
|
||||
type=str,
|
||||
choices=[
|
||||
"lstm",
|
||||
"blstm",
|
||||
"lstmp",
|
||||
"blstmp",
|
||||
"vgglstmp",
|
||||
"vggblstmp",
|
||||
"vgglstm",
|
||||
"vggblstm",
|
||||
"gru",
|
||||
"bgru",
|
||||
"grup",
|
||||
"bgrup",
|
||||
"vgggrup",
|
||||
"vggbgrup",
|
||||
"vgggru",
|
||||
"vggbgru",
|
||||
],
|
||||
help="Type of encoder network architecture",
|
||||
)
|
||||
group.add_argument(
|
||||
"--elayers",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Number of encoder layers",
|
||||
)
|
||||
group.add_argument(
|
||||
"--eunits",
|
||||
"-u",
|
||||
default=300,
|
||||
type=int,
|
||||
help="Number of encoder hidden units",
|
||||
)
|
||||
group.add_argument(
|
||||
"--eprojs", default=320, type=int, help="Number of encoder projection units"
|
||||
)
|
||||
group.add_argument(
|
||||
"--subsample",
|
||||
default="1",
|
||||
type=str,
|
||||
help="Subsample input frames x_y_z means "
|
||||
"subsample every x frame at 1st layer, "
|
||||
"every y frame at 2nd layer etc.",
|
||||
)
|
||||
return group
|
||||
|
||||
|
||||
def add_arguments_rnn_decoder_common(group):
|
||||
"""Define common arguments for RNN decoder."""
|
||||
group.add_argument(
|
||||
"--dtype",
|
||||
default="lstm",
|
||||
type=str,
|
||||
choices=["lstm", "gru"],
|
||||
help="Type of decoder network architecture",
|
||||
)
|
||||
group.add_argument(
|
||||
"--dlayers", default=1, type=int, help="Number of decoder layers"
|
||||
)
|
||||
group.add_argument(
|
||||
"--dunits", default=320, type=int, help="Number of decoder hidden units"
|
||||
)
|
||||
group.add_argument(
|
||||
"--dropout-rate-decoder",
|
||||
default=0.0,
|
||||
type=float,
|
||||
help="Dropout rate for the decoder",
|
||||
)
|
||||
group.add_argument(
|
||||
"--sampling-probability",
|
||||
default=0.0,
|
||||
type=float,
|
||||
help="Ratio of predicted labels fed back to decoder",
|
||||
)
|
||||
group.add_argument(
|
||||
"--lsm-type",
|
||||
const="",
|
||||
default="",
|
||||
type=str,
|
||||
nargs="?",
|
||||
choices=["", "unigram"],
|
||||
help="Apply label smoothing with a specified distribution type",
|
||||
)
|
||||
return group
|
||||
|
||||
|
||||
def add_arguments_rnn_attention_common(group):
|
||||
"""Define common arguments for RNN attention."""
|
||||
group.add_argument(
|
||||
"--atype",
|
||||
default="dot",
|
||||
type=str,
|
||||
choices=[
|
||||
"noatt",
|
||||
"dot",
|
||||
"add",
|
||||
"location",
|
||||
"coverage",
|
||||
"coverage_location",
|
||||
"location2d",
|
||||
"location_recurrent",
|
||||
"multi_head_dot",
|
||||
"multi_head_add",
|
||||
"multi_head_loc",
|
||||
"multi_head_multi_res_loc",
|
||||
],
|
||||
help="Type of attention architecture",
|
||||
)
|
||||
group.add_argument(
|
||||
"--adim",
|
||||
default=320,
|
||||
type=int,
|
||||
help="Number of attention transformation dimensions",
|
||||
)
|
||||
group.add_argument(
|
||||
"--awin", default=5, type=int, help="Window size for location2d attention"
|
||||
)
|
||||
group.add_argument(
|
||||
"--aheads",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Number of heads for multi head attention",
|
||||
)
|
||||
group.add_argument(
|
||||
"--aconv-chans",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="Number of attention convolution channels \
|
||||
(negative value indicates no location-aware attention)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--aconv-filts",
|
||||
default=100,
|
||||
type=int,
|
||||
help="Number of attention convolution filters \
|
||||
(negative value indicates no location-aware attention)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--dropout-rate",
|
||||
default=0.0,
|
||||
type=float,
|
||||
help="Dropout rate for the encoder",
|
||||
)
|
||||
return group
|
||||
Reference in New Issue
Block a user