fix trt wrapper bug

This commit is contained in:
lyuxiang.lx
2025-05-26 18:03:15 +08:00
parent 68100c267a
commit 3e12bb86bd
3 changed files with 14 additions and 2 deletions

View File

@@ -230,6 +230,9 @@ def add_optional_chunk_mask(xs: torch.Tensor,
else:
chunk_masks = masks
assert chunk_masks.dtype == torch.bool
if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
return chunk_masks