Merge remote-tracking branch 'origin/inference_streaming' into inference_streaming

This commit is contained in:
禾息
2024-09-03 11:13:25 +08:00
15 changed files with 754 additions and 6 deletions

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import Dict, Optional
import torch
import torch.nn as nn
@@ -77,6 +78,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.3 * j))
conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(feat_len)).to(h)

View File

@@ -82,10 +82,10 @@ class ConditionalCFM(BASECFM):
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
# Classifier-Free Guidance inference introduced in VoiceBox
if self.inference_cfg_rate > 0:
cfg_dphi_dt = self.forward_estimator(
cfg_dphi_dt = self.estimator(
x, mask,
torch.zeros_like(mu), t,
torch.zeros_like(spks) if spks is not None else None,