mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
update
This commit is contained in:
@@ -111,6 +111,10 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
prompt_feat_len,
|
||||
embedding,
|
||||
flow_cache):
|
||||
if self.fp16 is True:
|
||||
prompt_feat = prompt_feat.half()
|
||||
embedding = embedding.half()
|
||||
|
||||
assert token.shape[0] == 1
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
@@ -129,7 +133,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
|
||||
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
||||
conds[:, :mel_len1] = prompt_feat
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
@@ -145,7 +149,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
)
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat, flow_cache
|
||||
return feat.float(), flow_cache
|
||||
|
||||
|
||||
class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
@@ -196,6 +200,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
prompt_feat_len,
|
||||
embedding,
|
||||
finalize):
|
||||
if self.fp16 is True:
|
||||
prompt_feat = prompt_feat.half()
|
||||
embedding = embedding.half()
|
||||
|
||||
assert token.shape[0] == 1
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
@@ -214,7 +222,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
h = self.encoder_proj(h)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
|
||||
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
||||
conds[:, :mel_len1] = prompt_feat
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
@@ -228,4 +236,4 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
)
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat, None
|
||||
return feat.float(), None
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import onnxruntime
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from matcha.models.components.flow_matching import BASECFM
|
||||
@@ -52,7 +51,7 @@ class ConditionalCFM(BASECFM):
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
|
||||
z = torch.randn_like(mu) * temperature
|
||||
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
|
||||
cache_size = flow_cache.shape[2]
|
||||
# fix prompt and overlap part mu and z
|
||||
if cache_size != 0:
|
||||
@@ -89,36 +88,29 @@ class ConditionalCFM(BASECFM):
|
||||
# Or in future might add like a return_all_steps flag
|
||||
sol = []
|
||||
|
||||
if self.inference_cfg_rate > 0:
|
||||
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
||||
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
||||
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
|
||||
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
||||
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
||||
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
for step in range(1, len(t_span)):
|
||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||
if self.inference_cfg_rate > 0:
|
||||
x_in[:] = x
|
||||
mask_in[:] = mask
|
||||
mu_in[0] = mu
|
||||
t_in[:] = t.unsqueeze(0)
|
||||
spks_in[0] = spks
|
||||
cond_in[0] = cond
|
||||
else:
|
||||
x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
|
||||
x_in[:] = x
|
||||
mask_in[:] = mask
|
||||
mu_in[0] = mu
|
||||
t_in[:] = t.unsqueeze(0)
|
||||
spks_in[0] = spks
|
||||
cond_in[0] = cond
|
||||
dphi_dt = self.forward_estimator(
|
||||
x_in, mask_in,
|
||||
mu_in, t_in,
|
||||
spks_in,
|
||||
cond_in
|
||||
)
|
||||
if self.inference_cfg_rate > 0:
|
||||
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
||||
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
||||
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
||||
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
||||
x = x + dt * dphi_dt
|
||||
t = t + dt
|
||||
sol.append(x)
|
||||
@@ -130,17 +122,6 @@ class ConditionalCFM(BASECFM):
|
||||
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
||||
if isinstance(self.estimator, torch.nn.Module):
|
||||
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
||||
elif isinstance(self.estimator, onnxruntime.InferenceSession):
|
||||
ort_inputs = {
|
||||
'x': x.cpu().numpy(),
|
||||
'mask': mask.cpu().numpy(),
|
||||
'mu': mu.cpu().numpy(),
|
||||
't': t.cpu().numpy(),
|
||||
'spks': spks.cpu().numpy(),
|
||||
'cond': cond.cpu().numpy()
|
||||
}
|
||||
output = self.estimator.run(None, ort_inputs)[0]
|
||||
return torch.tensor(output, dtype=x.dtype, device=x.device)
|
||||
else:
|
||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
@@ -225,9 +206,7 @@ class CausalConditionalCFM(ConditionalCFM):
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
|
||||
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
|
||||
if self.fp16 is True:
|
||||
z = z.half()
|
||||
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
|
||||
# fix prompt and overlap part mu and z
|
||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
if self.t_scheduler == 'cosine':
|
||||
|
||||
Reference in New Issue
Block a user