diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py
index 814e4d1..6d88fc7 100644
--- a/cosyvoice/llm/llm.py
+++ b/cosyvoice/llm/llm.py
@@ -280,10 +280,14 @@ class Qwen2LM(torch.nn.Module):
             sampling: int,
             ignore_eos: bool = True,
     ):
+        num_trials, max_trials = 0, 100
         while True:
             top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             if (not ignore_eos) or (self.speech_token_size not in top_ids):
                 break
+            num_trials += 1
+            if num_trials > max_trials:
+                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
         return top_ids
 
     @torch.inference_mode()
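
For reference, below is a minimal standalone sketch of the bounded-retry pattern this patch adds: resample while EOS is drawn and ignore_eos is True, but give up after max_trials instead of looping forever. The `fake_sampler` function, the `speech_token_size` value, and the simplified `sampling_ids` signature are hypothetical stand-ins for illustration, not the actual CosyVoice API.

import random

# Assumed stop-token id for this sketch only; the real value lives on the Qwen2LM model.
SPEECH_TOKEN_SIZE = 6561

def fake_sampler():
    # Stand-in for self.sampling(...): occasionally returns the EOS token.
    return [random.choice([42, SPEECH_TOKEN_SIZE])]

def sampling_ids(ignore_eos=True, max_trials=100):
    num_trials = 0
    while True:
        top_ids = fake_sampler()
        # Accept the draw if EOS is allowed, or if EOS was not drawn this time.
        if (not ignore_eos) or (SPEECH_TOKEN_SIZE not in top_ids):
            break
        # EOS was drawn but is not allowed yet: count the retry and bail out
        # with an error once the budget is exhausted, rather than spinning forever.
        num_trials += 1
        if num_trials > max_trials:
            raise RuntimeError('sampling reached max_trials {} and still got eos'.format(max_trials))
    return top_ids

if __name__ == '__main__':
    print(sampling_ids())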