Mirror of https://github.com/HumanAIGC/lite-avatar.git, synced 2026-02-05 18:09:20 +08:00

Commit: add files
funasr_local/bin/tokenize_text.py (new file, 283 lines added)

@@ -0,0 +1,283 @@
#!/usr/bin/env python3
import argparse
from collections import Counter
import logging
from pathlib import Path
import sys
from typing import List
from typing import Optional

from typeguard import check_argument_types

from funasr_local.utils.cli_utils import get_commandline_args
from funasr_local.text.build_tokenizer import build_tokenizer
from funasr_local.text.cleaner import TextCleaner
from funasr_local.text.phoneme_tokenizer import g2p_choices
from funasr_local.utils.types import str2bool
from funasr_local.utils.types import str_or_none


def field2slice(field: Optional[str]) -> slice:
    """Convert a field string to a slice.

    Note that the field string uses 1-based integers.

    Examples:
        >>> field2slice("1-")
        slice(0, None, None)
        >>> field2slice("1-3")
        slice(0, 3, None)
        >>> field2slice("-3")
        slice(None, 3, None)
    """
    field = field.strip()
    try:
        if "-" in field:
            # e.g. "2-" or "2-5" or "-7"
            s1, s2 = field.split("-", maxsplit=1)
            if s1.strip() == "":
                s1 = None
            else:
                s1 = int(s1)
                if s1 == 0:
                    raise ValueError("1-based string")
            if s2.strip() == "":
                s2 = None
            else:
                s2 = int(s2)
        else:
            # e.g. "2"
            s1 = int(field)
            s2 = s1 + 1
            if s1 == 0:
                raise ValueError("must be 1 or more value")
    except ValueError:
        raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}")

    if s1 is None:
        slic = slice(None, s2)
    else:
        # -1 because the field is a 1-based integer, following the "cut" command
        # e.g. "1-3" -> slice(0, 3)
        slic = slice(s1 - 1, s2)
    return slic


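# Tokenize each line of `input` with the configured tokenizer and either write
# the tokenized text to `output`, or, in --write_vocabulary mode, collect token
# counts and write a frequency-sorted vocabulary list instead.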
def tokenize(
    input: str,
    output: str,
    field: Optional[str],
    delimiter: Optional[str],
    token_type: str,
    space_symbol: str,
    non_linguistic_symbols: Optional[str],
    bpemodel: Optional[str],
    log_level: str,
    write_vocabulary: bool,
    vocabulary_size: int,
    remove_non_linguistic_symbols: bool,
    cutoff: int,
    add_symbol: List[str],
    cleaner: Optional[str],
    g2p: Optional[str],
):
    assert check_argument_types()

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    # "-" selects stdin/stdout instead of a file path
    if input == "-":
        fin = sys.stdin
    else:
        fin = Path(input).open("r", encoding="utf-8")
    if output == "-":
        fout = sys.stdout
    else:
        p = Path(output)
        p.parent.mkdir(parents=True, exist_ok=True)
        fout = p.open("w", encoding="utf-8")

    cleaner = TextCleaner(cleaner)
    tokenizer = build_tokenizer(
        token_type=token_type,
        bpemodel=bpemodel,
        delimiter=delimiter,
        space_symbol=space_symbol,
        non_linguistic_symbols=non_linguistic_symbols,
        remove_non_linguistic_symbols=remove_non_linguistic_symbols,
        g2p_type=g2p,
    )

    counter = Counter()
    if field is not None:
        field = field2slice(field)

    for line in fin:
        line = line.rstrip()
        if field is not None:
            # e.g. field="2-"
            # uttidA hello world!! -> hello world!!
            tokens = line.split(delimiter)
            tokens = tokens[field]
            if delimiter is None:
                line = " ".join(tokens)
            else:
                line = delimiter.join(tokens)

        line = cleaner(line)
        tokens = tokenizer.text2tokens(line)
        if not write_vocabulary:
            fout.write(" ".join(tokens) + "\n")
        else:
            for t in tokens:
                counter[t] += 1

    if not write_vocabulary:
        return

    # FIXME: drop add_symbol entries from the counter so they are not written twice
    for symbol_and_id in add_symbol:
        # e.g. symbol="<blank>:0"
        try:
            symbol, idx = symbol_and_id.split(":")
        except ValueError:
            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
        symbol = symbol.strip()
        if symbol in counter:
            del counter[symbol]

    # ======= write_vocabulary mode from here =======
    # Sort by the number of occurrences in descending order
    # and drop words whose frequency does not exceed the cutoff value
    words_and_counts = list(
        filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
    )
    # Restrict the vocabulary size
    if vocabulary_size > 0:
        if vocabulary_size < len(add_symbol):
            raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
        words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]

    # Parse the values of --add_symbol
    for symbol_and_id in add_symbol:
        # e.g. symbol="<blank>:0"
        try:
            symbol, idx = symbol_and_id.split(":")
            idx = int(idx)
        except ValueError:
            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
        symbol = symbol.strip()

        # e.g. idx=0 -> insert as the first symbol
        # e.g. idx=-1 -> append as the last symbol
        if idx < 0:
            idx = len(words_and_counts) + 1 + idx
        words_and_counts.insert(idx, (symbol, None))

    # Write words
    for w, c in words_and_counts:
        fout.write(w + "\n")

    # Logging
    total_count = sum(counter.values())
    invocab_count = sum(c for w, c in words_and_counts if c is not None)
    logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Tokenize texts",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbosity level of logging",
    )

    parser.add_argument(
        "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
    )
    parser.add_argument(
        "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
    )
    parser.add_argument(
        "--field",
        "-f",
        help="The target columns of the input text as 1-based integers. e.g. 2-",
    )
    parser.add_argument(
        "--token_type",
        "-t",
        default="char",
        choices=["char", "bpe", "word", "phn"],
        help="Token type",
    )
    parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
    parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
    parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
    parser.add_argument(
        "--non_linguistic_symbols",
        type=str_or_none,
        help="non_linguistic_symbols file path",
    )
    parser.add_argument(
        "--remove_non_linguistic_symbols",
        type=str2bool,
        default=False,
        help="Remove non-linguistic symbols from tokens",
    )
    parser.add_argument(
        "--cleaner",
        type=str_or_none,
        choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
        default=None,
        help="Apply text cleaning",
    )
    parser.add_argument(
        "--g2p",
        type=str_or_none,
        choices=g2p_choices,
        default=None,
        help="Specify g2p method if --token_type=phn",
    )

    group = parser.add_argument_group("write_vocabulary mode related")
    group.add_argument(
        "--write_vocabulary",
        type=str2bool,
        default=False,
        help="Write tokens list instead of tokenized text per line",
    )
    group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
    group.add_argument(
        "--cutoff",
        default=0,
        type=int,
        help="cut-off frequency used for write-vocabulary mode",
    )
    group.add_argument(
        "--add_symbol",
        type=str,
        default=[],
        action="append",
        help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
    )

    return parser


def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    tokenize(**kwargs)


if __name__ == "__main__":
    main()
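A minimal usage sketch, assuming hypothetical data paths: main() forwards its argument list to argparse, so the tool can be exercised either from the command line or directly from Python as shown below.

# Usage sketch; "data/train/text" and "data/train/tokens.txt" are hypothetical paths.
from funasr_local.bin.tokenize_text import main

# 1) Tokenize the text columns (field 2-) of a Kaldi-style "uttid transcript" file
#    into characters and print one tokenized line per utterance to stdout.
main([
    "--input", "data/train/text",
    "--output", "-",
    "--field", "2-",
    "--token_type", "char",
])

# 2) Build a character vocabulary instead, limited to 5000 entries in total
#    (including the reserved symbols inserted at fixed positions).
main([
    "--input", "data/train/text",
    "--output", "data/train/tokens.txt",
    "--field", "2-",
    "--token_type", "char",
    "--write_vocabulary", "true",
    "--vocabulary_size", "5000",
    "--add_symbol", "<blank>:0",
    "--add_symbol", "<unk>:1",
    "--add_symbol", "<sos/eos>:-1",
])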