mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
88
funasr_local/layers/time_warp.py
Normal file
88
funasr_local/layers/time_warp.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Time warp module."""
|
||||
import torch
|
||||
|
||||
from funasr_local.modules.nets_utils import pad_list
|
||||
|
||||
DEFAULT_TIME_WARP_MODE = "bicubic"
|
||||
|
||||
|
||||
def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
|
||||
"""Time warping using torch.interpolate.
|
||||
|
||||
Args:
|
||||
x: (Batch, Time, Freq)
|
||||
window: time warp parameter
|
||||
mode: Interpolate mode
|
||||
"""
|
||||
|
||||
# bicubic supports 4D or more dimension tensor
|
||||
org_size = x.size()
|
||||
if x.dim() == 3:
|
||||
# x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
|
||||
x = x[:, None]
|
||||
|
||||
t = x.shape[2]
|
||||
if t - window <= window:
|
||||
return x.view(*org_size)
|
||||
|
||||
center = torch.randint(window, t - window, (1,))[0]
|
||||
warped = torch.randint(center - window, center + window, (1,))[0] + 1
|
||||
|
||||
# left: (Batch, Channel, warped, Freq)
|
||||
# right: (Batch, Channel, time - warped, Freq)
|
||||
left = torch.nn.functional.interpolate(
|
||||
x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=False
|
||||
)
|
||||
right = torch.nn.functional.interpolate(
|
||||
x[:, :, center:], (t - warped, x.shape[3]), mode=mode, align_corners=False
|
||||
)
|
||||
|
||||
if x.requires_grad:
|
||||
x = torch.cat([left, right], dim=-2)
|
||||
else:
|
||||
x[:, :, :warped] = left
|
||||
x[:, :, warped:] = right
|
||||
|
||||
return x.view(*org_size)
|
||||
|
||||
|
||||
class TimeWarp(torch.nn.Module):
|
||||
"""Time warping using torch.interpolate.
|
||||
|
||||
Args:
|
||||
window: time warp parameter
|
||||
mode: Interpolate mode
|
||||
"""
|
||||
|
||||
def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
|
||||
super().__init__()
|
||||
self.window = window
|
||||
self.mode = mode
|
||||
|
||||
def extra_repr(self):
|
||||
return f"window={self.window}, mode={self.mode}"
|
||||
|
||||
def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x: (Batch, Time, Freq)
|
||||
x_lengths: (Batch,)
|
||||
"""
|
||||
|
||||
if x_lengths is None or all(le == x_lengths[0] for le in x_lengths):
|
||||
# Note that applying same warping for each sample
|
||||
y = time_warp(x, window=self.window, mode=self.mode)
|
||||
else:
|
||||
# FIXME(kamo): I have no idea to batchify Timewarp
|
||||
ys = []
|
||||
for i in range(x.size(0)):
|
||||
_y = time_warp(
|
||||
x[i][None, : x_lengths[i]],
|
||||
window=self.window,
|
||||
mode=self.mode,
|
||||
)[0]
|
||||
ys.append(_y)
|
||||
y = pad_list(ys, 0.0)
|
||||
|
||||
return y, x_lengths
|
||||
Reference in New Issue
Block a user