mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
52
funasr_local/modules/mask.py
Normal file
52
funasr_local/modules/mask.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Mask module."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def subsequent_mask(size, device="cpu", dtype=torch.bool):
|
||||
"""Create mask for subsequent steps (size, size).
|
||||
|
||||
:param int size: size of mask
|
||||
:param str device: "cpu" or "cuda" or torch.Tensor.device
|
||||
:param torch.dtype dtype: result dtype
|
||||
:rtype: torch.Tensor
|
||||
>>> subsequent_mask(3)
|
||||
[[1, 0, 0],
|
||||
[1, 1, 0],
|
||||
[1, 1, 1]]
|
||||
"""
|
||||
ret = torch.ones(size, size, device=device, dtype=dtype)
|
||||
return torch.tril(ret, out=ret)
|
||||
|
||||
|
||||
def target_mask(ys_in_pad, ignore_id):
|
||||
"""Create mask for decoder self-attention.
|
||||
|
||||
:param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
|
||||
:param int ignore_id: index of padding
|
||||
:param torch.dtype dtype: result dtype
|
||||
:rtype: torch.Tensor (B, Lmax, Lmax)
|
||||
"""
|
||||
ys_mask = ys_in_pad != ignore_id
|
||||
m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
|
||||
return ys_mask.unsqueeze(-2) & m
|
||||
|
||||
def vad_mask(size, vad_pos, device="cpu", dtype=torch.bool):
|
||||
"""Create mask for decoder self-attention.
|
||||
|
||||
:param int size: size of mask
|
||||
:param int vad_pos: index of vad index
|
||||
:param str device: "cpu" or "cuda" or torch.Tensor.device
|
||||
:param torch.dtype dtype: result dtype
|
||||
:rtype: torch.Tensor (B, Lmax, Lmax)
|
||||
"""
|
||||
ret = torch.ones(size, size, device=device, dtype=dtype)
|
||||
if vad_pos <= 0 or vad_pos >= size:
|
||||
return ret
|
||||
sub_corner = torch.zeros(
|
||||
vad_pos - 1, size - vad_pos, device=device, dtype=dtype)
|
||||
ret[0:vad_pos - 1, vad_pos:] = sub_corner
|
||||
return ret
|
||||
Reference in New Issue
Block a user