From 9504c3f88b63335e6055fe70780a3d2ae3f81835 Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Thu, 1 Aug 2024 10:54:27 +0800
Subject: [PATCH] fix flow matching training for zero shot inference

---
 cosyvoice/flow/flow.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py
index 009160a..90a45b4 100644
--- a/cosyvoice/flow/flow.py
+++ b/cosyvoice/flow/flow.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 from typing import Dict, Optional
 import torch
 import torch.nn as nn
@@ -77,6 +78,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
 
         # get conditions
         conds = torch.zeros(feat.shape, device=token.device)
+        for i, j in enumerate(feat_len):
+            if random.random() < 0.5:
+                continue
+            index = random.randint(0, int(0.3 * j))
+            conds[i, :index] = feat[i, :index]
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(feat_len)).to(h)