This commit is contained in:
root
2025-07-29 08:40:51 +00:00
parent d1c354eac7
commit 62d082634e
7 changed files with 71 additions and 68 deletions

View File

@@ -102,6 +102,7 @@ import string
punctuation_all = punctuation + string.punctuation
Pathlike = Union[str, Path]
def remove_punctuation(text: str) -> str:
for x in punctuation_all:
if x == '\'':
@@ -109,6 +110,7 @@ def remove_punctuation(text: str) -> str:
text = text.replace(x, '')
return text
def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
) -> None:
@@ -304,6 +306,7 @@ def write_error_stats(
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate)
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -533,7 +536,7 @@ def get_args():
default=None,
help="wav_base_name label",
)
# Dataset-related arguments for loading labels when a label file is not provided
parser.add_argument(
"--dataset-name",
@@ -541,14 +544,14 @@ def get_args():
default="yuekai/seed_tts_cosy2",
help="Huggingface dataset name for loading labels",
)
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
help="Dataset split name for loading labels",
)
return parser.parse_args()
@@ -590,7 +593,7 @@ def normalize_text_alimeeting(text: str) -> str:
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
import re
text = text.replace('\u00A0', '') # test_hard
text = text.replace('\u00A0', '') # test_hard
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
@@ -685,10 +688,10 @@ def main():
print(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
# Load labels either from file or from dataset
labels_dict = {}
if args.label:
# Load labels from file (original functionality)
print(f"Loading labels from file: {args.label}")
@@ -716,11 +719,11 @@ def main():
split=args.split_name,
trust_remote_code=True,
)
for item in dataset:
audio_id = item["id"]
labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
print(f"Loaded {len(labels_dict)} labels from dataset")
# Perform evaluation if labels are available
@@ -750,4 +753,4 @@ def main():
if __name__ == "__main__":
main()
main()