mirror of
https://github.com/shivammehta25/Matcha-TTS.git
synced 2026-02-04 17:59:19 +08:00
23 lines
646 B
Python
23 lines
646 B
Python
import numpy as np
|
|
import torch
|
|
|
|
from matcha.utils.monotonic_align.core import maximum_path_c
|
|
|
|
|
|
def maximum_path(value, mask):
    """Compute a maximum-scoring path with the Cython ``maximum_path_c`` kernel.

    Args:
        value: score tensor of shape [b, t_x, t_y].
        mask: tensor of shape [b, t_x, t_y] marking the valid region of each
            batch item (assumed to be a 0/1 mask -- TODO confirm with callers).

    Returns:
        A tensor of shape [b, t_x, t_y] on ``value``'s device and with
        ``value``'s dtype. It holds the integer path array filled in by
        ``maximum_path_c``. Presumably this is 1 on the selected monotonic
        path and 0 elsewhere; verify against
        ``matcha.utils.monotonic_align.core``.
    """
    # Zero out scores outside the valid region before searching.
    value = value * mask
    device = value.device
    dtype = value.dtype

    # The kernel works on raw NumPy buffers, so move everything to the CPU
    # and drop autograd tracking. NOTE(review): the Cython signature is
    # assumed to use C-contiguous typed memoryviews (``[:, :, ::1]``), which
    # reject strided arrays. ``astype``/``zeros_like`` keep the input's
    # (possibly permuted) layout, so force a C-contiguous layout explicitly.
    value = np.ascontiguousarray(value.detach().cpu().numpy(), dtype=np.float32)
    path = np.zeros(value.shape, dtype=np.int32)
    mask = mask.detach().cpu().numpy()

    # Per-item valid lengths:
    #   t_x_max = unmasked x positions in the first y column;
    #   t_y_max = unmasked y positions in the first x row.
    t_x_max = np.ascontiguousarray(mask.sum(1)[:, 0], dtype=np.int32)
    t_y_max = np.ascontiguousarray(mask.sum(2)[:, 0], dtype=np.int32)

    # The kernel writes its result into ``path``; its return value is unused.
    maximum_path_c(path, value, t_x_max, t_y_max)
    return torch.from_numpy(path).to(device=device, dtype=dtype)
|