fix possible bug

This commit is contained in:
lyuxiang.lx
2025-06-25 14:31:05 +08:00
parent 63856565f3
commit 46dfe0439b

View File

@@ -127,6 +127,8 @@ class ConditionalCFM(BASECFM):
return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming) return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
else: else:
[estimator, stream], trt_engine = self.estimator.acquire_estimator() [estimator, stream], trt_engine = self.estimator.acquire_estimator()
# NOTE need to synchronize when switching stream
torch.cuda.current_stream().synchronize()
with stream: with stream:
estimator.set_input_shape('x', (2, 80, x.size(2))) estimator.set_input_shape('x', (2, 80, x.size(2)))
estimator.set_input_shape('mask', (2, 1, x.size(2))) estimator.set_input_shape('mask', (2, 1, x.size(2)))