Adding saving phones while getting durations from matcha

This commit is contained in:
Shivam Mehta
2024-03-02 12:47:08 +00:00
parent ad76016916
commit 294c6b1327
9 changed files with 126 additions and 12 deletions

View File

@@ -126,7 +126,7 @@ class FlowMatchingDurationPrediction(nn.Module):
self.n_steps = params.n_steps
@torch.inference_mode()
def forward(self, enc_outputs, mask, n_timesteps=None, temperature=1):
def forward(self, enc_outputs, mask, n_timesteps=500, temperature=1):
"""Forward diffusion
Args: