fix flow matching training for zero shot inference

This commit is contained in:
lyuxiang.lx
2024-08-01 10:54:27 +08:00
parent 553244b8f2
commit 9504c3f88b

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import random
from typing import Dict, Optional from typing import Dict, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -77,6 +78,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
# get conditions # get conditions
conds = torch.zeros(feat.shape, device=token.device) conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.3 * j))
conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2) conds = conds.transpose(1, 2)
mask = (~make_pad_mask(feat_len)).to(h) mask = (~make_pad_mask(feat_len)).to(h)