Merge pull request #811 from FunAudioLLM/dev/lyuxiang.lx

update
This commit is contained in:
Xiang Lyu
2024-12-30 17:41:58 +08:00
committed by GitHub

View File

@@ -280,10 +280,14 @@ class Qwen2LM(torch.nn.Module):
sampling: int,
ignore_eos: bool = True,
):
num_trials, max_trials = 0, 100
while True:
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
if (not ignore_eos) or (self.speech_token_size not in top_ids):
break
num_trials += 1
if num_trials > max_trials:
raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
return top_ids
@torch.inference_mode()