Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 18:09:24 +08:00)
fix flow matching training for zero shot inference
This commit is contained in:
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 from typing import Dict, Optional
 import torch
 import torch.nn as nn
@@ -77,6 +78,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
 
         # get conditions
         conds = torch.zeros(feat.shape, device=token.device)
+        for i, j in enumerate(feat_len):
+            if random.random() < 0.5:
+                continue
+            index = random.randint(0, int(0.3 * j))
+            conds[i, :index] = feat[i, :index]
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(feat_len)).to(h)
Reference in New Issue
Block a user