This commit is contained in:
lyuxiang.lx
2024-12-12 16:46:28 +08:00
parent 2345ce6be2
commit c693039d14
6 changed files with 145 additions and 71 deletions

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import onnxruntime
import torch
import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM
@@ -88,15 +89,25 @@ class ConditionalCFM(BASECFM):
# Or in future might add like a return_all_steps flag
sol = []
if self.inference_cfg_rate > 0:
# Do not use concat: it may change the memory format and make TRT inference return wrong results!
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
else:
x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
for step in range(1, len(t_span)):
# Classifier-Free Guidance inference introduced in VoiceBox
if self.inference_cfg_rate > 0:
x_in = torch.concat([x, x], dim=0)
mask_in = torch.concat([mask, mask], dim=0)
mu_in = torch.concat([mu, torch.zeros_like(mu).to(x.device)], dim=0)
t_in = torch.concat([t, t], dim=0)
spks_in = torch.concat([spks, torch.zeros_like(spks).to(x.device)], dim=0) if spks is not None else None
cond_in = torch.concat([cond, torch.zeros_like(cond).to(x.device)], dim=0) if cond is not None else None
x_in[:] = x
mask_in[:] = mask
mu_in[0] = mu
t_in[:] = t.unsqueeze(0)
spks_in[0] = spks
cond_in[0] = cond
else:
x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
dphi_dt = self.forward_estimator(
@@ -114,22 +125,53 @@ class ConditionalCFM(BASECFM):
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]
return sol[-1].float()
def forward_estimator(self, x, mask, mu, t, spks, cond):
    """Run the flow-matching estimator on whichever backend is loaded.

    ``self.estimator`` is dispatched on its type:

    * ``torch.nn.Module`` -- its ``forward`` is called directly.
    * ``onnxruntime.InferenceSession`` -- inputs are copied to CPU numpy
      arrays, the session is run, and the first output is wrapped back into
      a tensor with the dtype/device of ``x``.
    * anything else -- treated as a TensorRT execution context
      (NOTE(review): inferred from the ``set_input_shape`` / ``execute_v2``
      calls and the "run trt engine" comment; confirm against the loader).

    Args:
        x, mask, mu, t, spks, cond: estimator inputs, forwarded unchanged to
            the backend. The ONNX and TensorRT branches feed ``spks`` under
            the input name ``'spk'``.

    Returns:
        The estimator output as a tensor. On the TensorRT branch this is
        ``x`` itself (made contiguous), whose buffer is also passed as the
        last binding -- presumably the output binding, i.e. the result
        overwrites ``x`` in place.
    """
    if isinstance(self.estimator, torch.nn.Module):
        return self.estimator.forward(x, mask, mu, t, spks, cond)

    if isinstance(self.estimator, onnxruntime.InferenceSession):
        ort_inputs = {
            'x': x.cpu().numpy(),
            'mask': mask.cpu().numpy(),
            'mu': mu.cpu().numpy(),
            't': t.cpu().numpy(),
            'spk': spks.cpu().numpy(),
            'cond': cond.cpu().numpy(),
            'mask_rand': torch.randn(1, 1, 1).numpy(),
        }
        output = self.estimator.run(None, ort_inputs)[0]
        return torch.tensor(output, dtype=x.dtype, device=x.device)

    # Raw device pointers are handed to the engine below, so every buffer
    # must be densely laid out. ``contiguous()`` returns the tensor itself
    # when it already is, so this matches the per-tensor checks it replaces.
    x, mask, mu, t, spks, cond = (
        tensor.contiguous() for tensor in (x, mask, mu, t, spks, cond)
    )
    # Keep a named reference: taking ``data_ptr()`` of a temporary would let
    # the tensor be released before the engine reads from that address.
    mask_rand = torch.randn(1, 1, 1).to(x.device)

    frames = x.size(2)
    # NOTE(review): batch size 2 and 80 channels are hard-coded here; they
    # match the buffers the classifier-free-guidance path allocates, but
    # not a batch-1 call. Confirm callers never reach this branch otherwise.
    self.estimator.set_input_shape('x', (2, 80, frames))
    self.estimator.set_input_shape('mask', (2, 1, frames))
    self.estimator.set_input_shape('mu', (2, 80, frames))
    self.estimator.set_input_shape('t', (2,))
    self.estimator.set_input_shape('spk', (2, 80))
    self.estimator.set_input_shape('cond', (2, 80, frames))
    self.estimator.set_input_shape('mask_rand', (1, 1, 1))
    # run trt engine
    self.estimator.execute_v2([
        x.data_ptr(),
        mask.data_ptr(),
        mu.data_ptr(),
        t.data_ptr(),
        spks.data_ptr(),
        cond.data_ptr(),
        mask_rand.data_ptr(),
        x.data_ptr(),
    ])
    return x
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss
@@ -199,7 +241,8 @@ class CausalConditionalCFM(ConditionalCFM):
"""
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
z[:] = 0
if self.sp16 is True:
z = z.half()
# fix prompt and overlap part mu and z
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':