mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
@@ -280,10 +280,14 @@ class Qwen2LM(torch.nn.Module):
|
|||||||
sampling: int,
|
sampling: int,
|
||||||
ignore_eos: bool = True,
|
ignore_eos: bool = True,
|
||||||
):
|
):
|
||||||
|
num_trials, max_trials = 0, 100
|
||||||
while True:
|
while True:
|
||||||
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
||||||
if (not ignore_eos) or (self.speech_token_size not in top_ids):
|
if (not ignore_eos) or (self.speech_token_size not in top_ids):
|
||||||
break
|
break
|
||||||
|
num_trials += 1
|
||||||
|
if num_trials > max_trials:
|
||||||
|
raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
|
||||||
return top_ids
|
return top_ids
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
|||||||
Reference in New Issue
Block a user