This commit is contained in:
zhoubofan.zbf
2024-08-30 13:43:54 +08:00
parent 6e7f5b922a
commit 29408360fb
4 changed files with 18 additions and 16 deletions

View File

@@ -99,10 +99,10 @@ class ConditionalCFM(BASECFM):
def forward_estimator(self, x, mask, mu, t, spks, cond):
if not isinstance(self.estimator, torch.nn.Module):
if self.estimator is not None:
return self.estimator.forward(x, mask, mu, t, spks, cond)
else:
print("-----------")
assert self.training is False, 'tensorrt cannot be used in training'
bs = x.shape[0]
hs = x.shape[1]
@@ -119,10 +119,10 @@ class ConditionalCFM(BASECFM):
names = ['x', 'mask', 'mu', 't', 'spks', 'cond', 'estimator_out']
for i in range(len(bindings)):
self.estimator.set_tensor_address(names[i], bindings[i])
self.estimator_context.set_tensor_address(names[i], bindings[i])
handle = torch.cuda.current_stream().cuda_stream
self.estimator.execute_async_v3(stream_handle=handle)
self.estimator_context.execute_async_v3(stream_handle=handle)
return ret
def compute_loss(self, x1, mask, mu, spks=None, cond=None):