Mirror of https://github.com/FunAudioLLM/CosyVoice.git, synced 2026-02-05 18:09:24 +08:00

Commit: update
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import onnxruntime
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
@@ -88,15 +89,25 @@ class ConditionalCFM(BASECFM):
         # Or in future might add like a return_all_steps flag
         sol = []
 
+        if self.inference_cfg_rate > 0:
+            # Do not use concat, it may cause memory format changed and trt infer with wrong results!
+            x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+            mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+            mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+            t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
+            spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
+            cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        else:
+            x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
         for step in range(1, len(t_span)):
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                x_in = torch.concat([x, x], dim=0)
-                mask_in = torch.concat([mask, mask], dim=0)
-                mu_in = torch.concat([mu, torch.zeros_like(mu).to(x.device)], dim=0)
-                t_in = torch.concat([t, t], dim=0)
-                spks_in = torch.concat([spks, torch.zeros_like(spks).to(x.device)], dim=0) if spks is not None else None
-                cond_in = torch.concat([cond, torch.zeros_like(cond).to(x.device)], dim=0) if cond is not None else None
+                x_in[:] = x
+                mask_in[:] = mask
+                mu_in[0] = mu
+                t_in[:] = t.unsqueeze(0)
+                spks_in[0] = spks
+                cond_in[0] = cond
             else:
                 x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
             dphi_dt = self.forward_estimator(
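Note: the in-file comment above states the motivation for this hunk. Building the CFG batch with torch.concat allocates a fresh tensor on every solver step and can change the memory format the TensorRT engine was built for, so the change fills a pre-allocated, contiguous [2, ...] buffer in place instead. A minimal sketch of the equivalence, with illustrative shapes that are not taken from the commit:

import torch

C, T = 80, 128                                # illustrative mel bins / frame count
mu = torch.randn(1, C, T)                     # conditioning mean, batch of 1

# concat-based CFG batch: a new tensor (and possibly new memory format) every step
cfg_concat = torch.concat([mu, torch.zeros_like(mu)], dim=0)

# buffer-based CFG batch: allocate once, refill in place each step
mu_in = torch.zeros(2, C, T)
mu_in[0] = mu                                 # conditional half; unconditional half stays zero
assert torch.equal(cfg_concat, mu_in)

ptr_before = mu_in.data_ptr()
mu_in[0] = torch.randn(1, C, T)               # next step's refill
assert mu_in.data_ptr() == ptr_before         # same buffer, same device address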
@@ -114,22 +125,53 @@ class ConditionalCFM(BASECFM):
             if step < len(t_span) - 1:
                 dt = t_span[step + 1] - t
 
-        return sol[-1]
+        return sol[-1].float()
 
     def forward_estimator(self, x, mask, mu, t, spks, cond):
         if isinstance(self.estimator, torch.nn.Module):
             return self.estimator.forward(x, mask, mu, t, spks, cond)
-        else:
+        elif isinstance(self.estimator, onnxruntime.InferenceSession):
             ort_inputs = {
                 'x': x.cpu().numpy(),
                 'mask': mask.cpu().numpy(),
                 'mu': mu.cpu().numpy(),
                 't': t.cpu().numpy(),
-                'spks': spks.cpu().numpy(),
-                'cond': cond.cpu().numpy()
+                'spk': spks.cpu().numpy(),
+                'cond': cond.cpu().numpy(),
+                'mask_rand': torch.randn(1, 1, 1).numpy()
             }
             output = self.estimator.run(None, ort_inputs)[0]
             return torch.tensor(output, dtype=x.dtype, device=x.device)
+        else:
+            if not x.is_contiguous():
+                x = x.contiguous()
+            if not mask.is_contiguous():
+                mask = mask.contiguous()
+            if not mu.is_contiguous():
+                mu = mu.contiguous()
+            if not t.is_contiguous():
+                t = t.contiguous()
+            if not spks.is_contiguous():
+                spks = spks.contiguous()
+            if not cond.is_contiguous():
+                cond = cond.contiguous()
+            self.estimator.set_input_shape('x', (2, 80, x.size(2)))
+            self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
+            self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
+            self.estimator.set_input_shape('t', (2,))
+            self.estimator.set_input_shape('spk', (2, 80))
+            self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+            self.estimator.set_input_shape('mask_rand', (1, 1, 1))
+            # run trt engine
+            self.estimator.execute_v2([x.data_ptr(),
+                                       mask.data_ptr(),
+                                       mu.data_ptr(),
+                                       t.data_ptr(),
+                                       spks.data_ptr(),
+                                       cond.data_ptr(),
+                                       torch.randn(1, 1, 1).to(x.device).data_ptr(),
+                                       x.data_ptr()])
+            return x
 
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss
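Note: the new TensorRT branch hands execute_v2 raw data_ptr() addresses, which is only valid when each tensor is stored contiguously in the layout the engine expects; that is what the is_contiguous()/contiguous() guards are for. A small, self-contained illustration of the failure mode they prevent (tensor names and shapes are illustrative):

import torch

x = torch.randn(2, 80, 128)
view = x.transpose(1, 2)                   # same storage, strided view, shape (2, 128, 80)
print(view.is_contiguous())                # False: memory order != logical element order

# Binding view.data_ptr() directly would feed the engine data in the wrong order,
# so a contiguous copy is made first (a no-op when the tensor is already contiguous).
safe = view if view.is_contiguous() else view.contiguous()
print(safe.is_contiguous())                # True
print(safe.data_ptr() != view.data_ptr())  # True: the copy lives in a new buffer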
@@ -199,7 +241,8 @@ class CausalConditionalCFM(ConditionalCFM):
         """
 
         z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
-        z[:] = 0
+        if self.sp16 is True:
+            z = z.half()
         # fix prompt and overlap part mu and z
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
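Note: the hunk ends at the t_scheduler check. In the upstream CosyVoice implementation this 'cosine' branch warps the uniform Euler grid with t -> 1 - cos(pi * t / 2), concentrating solver steps near t = 0. A quick numeric illustration (step count chosen arbitrarily):

import torch

n_timesteps = 10
t_span = torch.linspace(0, 1, n_timesteps + 1)   # uniform grid, as built in the hunk above

# cosine warp applied by the t_scheduler == 'cosine' branch in upstream CosyVoice
t_cosine = 1 - torch.cos(t_span * 0.5 * torch.pi)

print(t_span[:4])    # tensor([0.0000, 0.1000, 0.2000, 0.3000])
print(t_cosine[:4])  # tensor([0.0000, 0.0123, 0.0489, 0.1090])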