diff --git a/tuning/utils.py b/tuning/utils.py index c4d58a5..46ee40c 100644 --- a/tuning/utils.py +++ b/tuning/utils.py @@ -240,6 +240,7 @@ def train(config, loss = criterion(stacked, targets) loss = (loss * masks).mean() + optimizer.zero_grad() loss.backward() optimizer.step() losses.update(loss.item(), masks.numel())