diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index 009160a..90a45b4 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import random from typing import Dict, Optional import torch import torch.nn as nn @@ -77,6 +78,11 @@ class MaskedDiffWithXvec(torch.nn.Module): # get conditions conds = torch.zeros(feat.shape, device=token.device) + for i, j in enumerate(feat_len): + if random.random() < 0.5: + continue + index = random.randint(0, int(0.3 * j)) + conds[i, :index] = feat[i, :index] conds = conds.transpose(1, 2) mask = (~make_pad_mask(feat_len)).to(h)