From b60c37b31ae0409de078d798c99fbfa5f187146d Mon Sep 17 00:00:00 2001
From: 0xCAFEBABE0
Date: Mon, 30 Dec 2024 10:48:43 +0800
Subject: [PATCH] fix(bug).when generating text that contains only punctuation
 marks or whitespace characters, the CPU usage reaches 100%, and the process
 crashes.

---
 cosyvoice/cli/frontend.py         | 6 +++++-
 cosyvoice/cli/model.py            | 5 -----
 cosyvoice/utils/common.py         | 7 -------
 cosyvoice/utils/frontend_utils.py | 7 +++++++
 4 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py
index 9885a0f..31926f0 100644
--- a/cosyvoice/cli/frontend.py
+++ b/cosyvoice/cli/frontend.py
@@ -31,7 +31,7 @@ except ImportError:
     from tn.chinese.normalizer import Normalizer as ZhNormalizer
     from tn.english.normalizer import Normalizer as EnNormalizer
     use_ttsfrd = False
-from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
+from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
 
 
 class CosyVoiceFrontEnd:
@@ -109,6 +109,10 @@ class CosyVoiceFrontEnd:
 
     def text_normalize(self, text, split=True):
         text = text.strip()
+        # When generating text that contains only punctuation marks or whitespace characters
+        # - Returning empty texts ensures consistent processing logic.
+        if is_only_punctuation(text):
+            return []
         if contains_chinese(text):
             if self.use_ttsfrd:
                 texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 5314e8b..b9f555b 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -19,7 +19,6 @@ from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
-from cosyvoice.utils.common import is_only_punctuation
 
 
 class CosyVoiceModel:
@@ -146,10 +145,6 @@ class CosyVoiceModel:
             llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
             flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
             prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
-        # When generating text that contains only punctuation marks or whitespace characters
-        # - Returning 10ms of silence ensures consistent processing logic.
-        if is_only_punctuation(text):
-            return {'tts_speech': torch.zeros(1, int(0.01 * 22050))}
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
         with self.lock:
diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py
index f162cbe..b356f0c 100644
--- a/cosyvoice/utils/common.py
+++ b/cosyvoice/utils/common.py
@@ -20,7 +20,6 @@ from typing import List
 
 import numpy as np
 import torch
-import regex
 
 IGNORE_ID = -1
 
@@ -156,12 +155,6 @@ def set_all_random_seed(seed):
     torch.cuda.manual_seed_all(seed)
 
 
-def is_only_punctuation(text):
-    # Regular expression: Match strings that consist only of punctuation marks or are empty.
-    punctuation_pattern = r'^[\p{P}\p{S}]*$'
-    return bool(regex.fullmatch(punctuation_pattern, text))
-
-
 def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     assert mask.dtype == torch.bool
     assert dtype in [torch.float32, torch.bfloat16, torch.float16]
diff --git a/cosyvoice/utils/frontend_utils.py b/cosyvoice/utils/frontend_utils.py
index ab01a1f..ea1c9fc 100644
--- a/cosyvoice/utils/frontend_utils.py
+++ b/cosyvoice/utils/frontend_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import re
+import regex
 chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
 
 
@@ -127,3 +128,9 @@ def replace_blank(text: str):
         else:
             out_str.append(c)
     return "".join(out_str)
+
+
+def is_only_punctuation(text):
+    # True when text is empty or holds only punctuation (\p{P}), symbols (\p{S}) and whitespace, e.g. "! ?".
+    punctuation_pattern = r'^[\p{P}\p{S}\s]*$'
+    return bool(regex.fullmatch(punctuation_pattern, text))