diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py index 43ced32..f162cbe 100644 --- a/cosyvoice/utils/common.py +++ b/cosyvoice/utils/common.py @@ -161,6 +161,7 @@ def is_only_punctuation(text): punctuation_pattern = r'^[\p{P}\p{S}]*$' return bool(regex.fullmatch(punctuation_pattern, text)) + def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: assert mask.dtype == torch.bool assert dtype in [torch.float32, torch.bfloat16, torch.float16]