mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
fix bug
This commit is contained in:
@@ -99,10 +99,10 @@ class ConditionalCFM(BASECFM):
|
||||
|
||||
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
||||
|
||||
if not isinstance(self.estimator, torch.nn.Module):
|
||||
if self.estimator is not None:
|
||||
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
||||
|
||||
else:
|
||||
print("-----------")
|
||||
assert self.training is False, 'tensorrt cannot be used in training'
|
||||
bs = x.shape[0]
|
||||
hs = x.shape[1]
|
||||
@@ -119,10 +119,10 @@ class ConditionalCFM(BASECFM):
|
||||
names = ['x', 'mask', 'mu', 't', 'spks', 'cond', 'estimator_out']
|
||||
|
||||
for i in range(len(bindings)):
|
||||
self.estimator.set_tensor_address(names[i], bindings[i])
|
||||
self.estimator_context.set_tensor_address(names[i], bindings[i])
|
||||
|
||||
handle = torch.cuda.current_stream().cuda_stream
|
||||
self.estimator.execute_async_v3(stream_handle=handle)
|
||||
self.estimator_context.execute_async_v3(stream_handle=handle)
|
||||
return ret
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
||||
|
||||
Reference in New Issue
Block a user