mirror of https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00

add files

0    funasr_local/modules/__init__.py    Normal file
31   funasr_local/modules/add_sos_eos.py    Normal file
@@ -0,0 +1,31 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Utility functions for Transformer."""

import torch

from funasr_local.modules.nets_utils import pad_list


def add_sos_eos(ys_pad, sos, eos, ignore_id):
    """Add <sos> and <eos> labels.

    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param int sos: index of <sos>
    :param int eos: index of <eos>
    :param int ignore_id: index of padding
    :return: tuple of the input tensor with <sos> prepended (padded with eos)
        and the output tensor with <eos> appended (padded with ignore_id),
        each of shape (B, Lmax + 1)
    :rtype: Tuple[torch.Tensor, torch.Tensor]
    """
    _sos = ys_pad.new([sos])
    _eos = ys_pad.new([eos])
    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
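A minimal usage sketch (not part of the commit; the token ids 1, 2, and -1 for <sos>, <eos>, and padding are illustrative assumptions):

import torch
from funasr_local.modules.add_sos_eos import add_sos_eos

ys_pad = torch.tensor([[3, 4, 5], [6, 7, -1]])  # second sequence is one token shorter
ys_in, ys_out = add_sos_eos(ys_pad, sos=1, eos=2, ignore_id=-1)
# ys_in:  [[1, 3, 4, 5], [1, 6, 7, 2]]   -> <sos> prepended, padded with eos
# ys_out: [[3, 4, 5, 2], [6, 7, 2, -1]]  -> <eos> appended, padded with ignore_id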
961  funasr_local/modules/attention.py    Normal file
@@ -0,0 +1,961 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Multi-Head Attention layer definition."""

import math

import numpy
import torch
from torch import nn
from typing import Optional, Tuple


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, query, key, value):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).

        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)

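A quick self-attention smoke test under assumed shapes (batch 2, 50 frames, 256 features; not part of the commit):

import torch
from funasr_local.modules.attention import MultiHeadedAttention

attn = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
x = torch.randn(2, 50, 256)                    # (batch, time, n_feat)
mask = torch.ones(2, 1, 50, dtype=torch.bool)  # (batch, 1, time2); nonzero = keep
out = attn(x, x, x, mask)                      # -> (2, 50, 256)
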
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (old version).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.

    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct a LegacyRelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, time2).

        Returns:
            torch.Tensor: Output tensor.

        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)

        if self.zero_triu:
            ones = torch.ones((x.size(2), x.size(3)))
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).

        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)

class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.

    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct a RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor.

        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2

        if self.zero_triu:
            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, 2*time1-1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).

        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, 2*time1-1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)

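A usage sketch for the relative-position variant; in a real model pos_emb comes from a relative positional encoding module, so the random tensor below is only a shape placeholder:

import torch
from funasr_local.modules.attention import RelPositionMultiHeadedAttention

attn = RelPositionMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
x = torch.randn(2, 50, 256)
pos_emb = torch.randn(1, 2 * 50 - 1, 256)      # (1, 2*time1-1, n_feat) placeholder
mask = torch.ones(2, 1, 50, dtype=torch.bool)
out = attn(x, x, x, pos_emb, mask)             # -> (2, 50, 256)
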
class MultiHeadedAttentionSANM(nn.Module):
    """Multi-Head Attention layer with an FSMN memory block (SANM).

    Args:
        n_head (int): The number of heads.
        in_feat (int): The number of input features.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the FSMN memory block.
        sanm_shfit (int): Extra left shift of the FSMN padding.

    """

    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Construct a MultiHeadedAttentionSANM object."""
        super(MultiHeadedAttentionSANM, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_out = nn.Linear(n_feat, n_feat)
        # fused projection producing query, key and value in one pass
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
        # padding
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask

        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x):
        """Transform the input into query, key and value.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
            torch.Tensor: Untransposed value tensor (#batch, time, n_feat).

        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)

        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder

            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)

            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute scaled dot product attention plus the FSMN memory.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
                (#batch, time, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, d_model).

        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory

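A shape sketch for the SANM layer (sizes are illustrative; the float mask uses 1.0 for valid frames, matching both the FSMN multiply and the eq(0) check above):

import torch
from funasr_local.modules.attention import MultiHeadedAttentionSANM

attn = MultiHeadedAttentionSANM(
    n_head=4, in_feat=256, n_feat=256, dropout_rate=0.1, kernel_size=11
)
x = torch.randn(2, 50, 256)   # (batch, time, in_feat)
mask = torch.ones(2, 1, 50)   # (batch, 1, time); 1.0 = valid frame
out = attn(x, mask)           # -> (2, 50, 256): attention output + FSMN memory
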
class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        # mask is a pair: mask[0] for the FSMN branch, mask[1] for attention
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
        return att_outs + fsmn_memory

class MultiHeadedAttentionSANMDecoder(nn.Module):
    """FSMN memory block for the SANM decoder (no attention, convolution only).

    Args:
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the FSMN block.
        sanm_shfit (int): Extra left shift of the FSMN padding.

    """

    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Construct a MultiHeadedAttentionSANMDecoder object."""
        super(MultiHeadedAttentionSANMDecoder, self).__init__()

        self.dropout = nn.Dropout(p=dropout_rate)

        self.fsmn_block = nn.Conv1d(n_feat, n_feat,
                                    kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
        # padding
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.kernel_size = kernel_size

    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
        """
        :param torch.Tensor inputs: Input tensor (#batch, time1, size)
        :param torch.Tensor mask: Mask tensor (#batch, 1, time)
        :param torch.Tensor cache: Padded context carried over from the previous step
        :return: Output tensor and the updated cache
        """
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask

        x = inputs.transpose(1, 2)
        b, d, t = x.size()
        if cache is None:
            x = self.pad_fn(x)
            if not self.training:
                cache = x
        else:
            # shift the cached context left by one frame and append the new input
            x = torch.cat((cache[:, :, 1:], x), dim=2)
            x = x[:, :, -(self.kernel_size + t - 1):]
            cache = x
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        if x.size(1) != inputs.size(1):
            inputs = inputs[:, -1, :]

        x = x + inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x, cache

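A streaming sketch for the decoder FSMN block (illustrative sizes; not part of the commit): in eval mode the first call builds the cache from the padded input, and later calls shift it one frame at a time:

import torch
from funasr_local.modules.attention import MultiHeadedAttentionSANMDecoder

fsmn = MultiHeadedAttentionSANMDecoder(n_feat=256, dropout_rate=0.0, kernel_size=11)
fsmn.eval()

frame = torch.randn(2, 1, 256)                    # one decoding step (batch, 1, n_feat)
out, cache = fsmn(frame, mask=None)               # first call: cache built internally
out, cache = fsmn(frame, mask=None, cache=cache)  # later calls reuse the cache
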
class MultiHeadedAttentionCrossAtt(nn.Module):
    """Multi-Head cross-attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        encoder_output_size (int, optional): Feature size of the encoder memory;
            defaults to n_feat.

    """

    def __init__(self, n_head, n_feat, dropout_rate, encoder_output_size=None):
        """Construct a MultiHeadedAttentionCrossAtt object."""
        super(MultiHeadedAttentionCrossAtt, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        # fused projection producing key and value from the encoder memory
        self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, x, memory):
        """Transform the query from x and key/value from the encoder memory.

        Args:
            x (torch.Tensor): Query input tensor (#batch, time1, size).
            memory (torch.Tensor): Encoder memory tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).

        """
        b = x.size(0)
        q = self.linear_q(x)
        q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)

        k_v = self.linear_k_v(memory)
        k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
        k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)

        return q_h, k_h, v_h

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, memory, memory_mask):
        """Compute scaled dot product cross-attention.

        Args:
            x (torch.Tensor): Query input tensor (#batch, time1, size).
            memory (torch.Tensor): Encoder memory tensor (#batch, time2, size).
            memory_mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).

        """
        q_h, k_h, v_h = self.forward_qkv(x, memory)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        return self.forward_attention(v_h, scores, memory_mask)

class MultiHeadSelfAttention(nn.Module):
    """Multi-Head self-attention layer with a fused QKV projection.

    Args:
        n_head (int): The number of heads.
        in_feat (int): The number of input features.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
        """Construct a MultiHeadSelfAttention object."""
        super(MultiHeadSelfAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, x):
        """Transform the input into query, key and value.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
            torch.Tensor: Untransposed value tensor (#batch, time, n_feat).

        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)

        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder

            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)

            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_att_chunk_encoder=None):
        """Compute scaled dot product attention.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
                (#batch, time, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, d_model).

        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs

class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
    """RelPositionMultiHeadedAttention definition.

    Args:
        num_heads: Number of attention heads.
        embed_size: Embedding size.
        dropout_rate: Dropout rate.

    """

    def __init__(
        self,
        num_heads: int,
        embed_size: int,
        dropout_rate: float = 0.0,
        simplified_attention_score: bool = False,
    ) -> None:
        """Construct a RelPositionMultiHeadedAttentionChunk object."""
        super().__init__()

        self.d_k = embed_size // num_heads
        self.num_heads = num_heads

        assert self.d_k * num_heads == embed_size, (
            "embed_size (%d) must be divisible by num_heads (%d)"
            % (embed_size, num_heads)
        )

        self.linear_q = torch.nn.Linear(embed_size, embed_size)
        self.linear_k = torch.nn.Linear(embed_size, embed_size)
        self.linear_v = torch.nn.Linear(embed_size, embed_size)

        self.linear_out = torch.nn.Linear(embed_size, embed_size)

        if simplified_attention_score:
            self.linear_pos = torch.nn.Linear(embed_size, num_heads)

            self.compute_att_score = self.compute_simplified_attention_score
        else:
            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)

            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)

            self.compute_att_score = self.compute_attention_score

        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.attn = None

    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Compute relative positional encoding.

        Args:
            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
            left_context: Number of frames in left context.

        Returns:
            x: Output sequence. (B, H, T_1, T_2)

        """
        batch_size, n_heads, time1, n = x.shape
        time2 = time1 + left_context

        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()

        return x.as_strided(
            (batch_size, n_heads, time1, time2),
            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
            storage_offset=(n_stride * (time1 - 1)),
        )

    def compute_simplified_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Simplified attention score computation.

        Reference: https://github.com/k2-fsa/icefall/pull/458

        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.

        Returns:
            : Attention score. (B, H, T_1, T_2)

        """
        pos_enc = self.linear_pos(pos_enc)

        matrix_ac = torch.matmul(query, key.transpose(2, 3))

        matrix_bd = self.rel_shift(
            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
            left_context=left_context,
        )

        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def compute_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Attention score computation.

        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.

        Returns:
            : Attention score. (B, H, T_1, T_2)

        """
        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)

        query = query.transpose(1, 2)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)

        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))

        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)

        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)

        Returns:
            q: Transformed query tensor. (B, H, T_1, d_k)
            k: Transformed key tensor. (B, H, T_2, d_k)
            v: Transformed value tensor. (B, H, T_2, d_k)

        """
        n_batch = query.size(0)

        q = (
            self.linear_q(query)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        k = (
            self.linear_k(key)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        v = (
            self.linear_v(value)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )

        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value: Transformed value. (B, H, T_2, d_k)
            scores: Attention score. (B, H, T_1, T_2)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)

        Returns:
            attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)

        """
        batch_size = scores.size(0)
        mask = mask.unsqueeze(1).unsqueeze(2)
        if chunk_mask is not None:
            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
        scores = scores.masked_fill(mask, float("-inf"))
        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)

        attn_output = self.dropout(self.attn)
        attn_output = torch.matmul(attn_output, value)

        attn_output = self.linear_out(
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_k)
        )

        return attn_output

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Compute scaled dot product attention with rel. positional encoding.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)
            left_context: Number of frames in left context.

        Returns:
            : Output tensor. (B, T_1, H * d_k)

        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
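A usage sketch for the chunk variant (shapes illustrative; here the boolean source mask marks padded frames with True, matching the masked_fill calls above):

import torch
from funasr_local.modules.attention import RelPositionMultiHeadedAttentionChunk

attn = RelPositionMultiHeadedAttentionChunk(num_heads=4, embed_size=256)
x = torch.randn(2, 50, 256)
pos_enc = torch.randn(2, 2 * 50 - 1, 256)     # (B, 2*T_1-1, size) placeholder
mask = torch.zeros(2, 50, dtype=torch.bool)   # True would mark padded frames
out = attn(x, x, x, pos_enc, mask)            # -> (2, 50, 256)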
0    funasr_local/modules/beam_search/__init__.py    Normal file
348  funasr_local/modules/beam_search/batch_beam_search.py    Normal file
@@ -0,0 +1,348 @@
"""Parallel beam search module."""

import logging
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

from funasr_local.modules.beam_search.beam_search import BeamSearch
from funasr_local.modules.beam_search.beam_search import Hypothesis


class BatchHypothesis(NamedTuple):
    """Batchified/vectorized hypothesis data type."""

    yseq: torch.Tensor = torch.tensor([])  # (batch, maxlen)
    score: torch.Tensor = torch.tensor([])  # (batch,)
    length: torch.Tensor = torch.tensor([])  # (batch,)
    scores: Dict[str, torch.Tensor] = dict()  # values: (batch,)
    states: Dict[str, Dict] = dict()

    def __len__(self) -> int:
        """Return the batch size."""
        return len(self.length)

class BatchBeamSearch(BeamSearch):
    """Batch beam search implementation."""

    def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
        """Convert list to batch."""
        if len(hyps) == 0:
            return BatchHypothesis()
        return BatchHypothesis(
            yseq=pad_sequence(
                [h.yseq for h in hyps], batch_first=True, padding_value=self.eos
            ),
            length=torch.tensor([len(h.yseq) for h in hyps], dtype=torch.int64),
            score=torch.tensor([h.score for h in hyps]),
            scores={k: torch.tensor([h.scores[k] for h in hyps]) for k in self.scorers},
            states={k: [h.states[k] for h in hyps] for k in self.scorers},
        )

    def _batch_select(self, hyps: BatchHypothesis, ids: List[int]) -> BatchHypothesis:
        return BatchHypothesis(
            yseq=hyps.yseq[ids],
            score=hyps.score[ids],
            length=hyps.length[ids],
            scores={k: v[ids] for k, v in hyps.scores.items()},
            states={
                k: [self.scorers[k].select_state(v, i) for i in ids]
                for k, v in hyps.states.items()
            },
        )

    def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis:
        return Hypothesis(
            yseq=hyps.yseq[i, : hyps.length[i]],
            score=hyps.score[i],
            scores={k: v[i] for k, v in hyps.scores.items()},
            states={
                k: self.scorers[k].select_state(v, i) for k, v in hyps.states.items()
            },
        )

    def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
        """Revert batch to list."""
        return [
            Hypothesis(
                yseq=batch_hyps.yseq[i][: batch_hyps.length[i]],
                score=batch_hyps.score[i],
                scores={k: batch_hyps.scores[k][i] for k in self.scorers},
                states={
                    k: v.select_state(batch_hyps.states[k], i)
                    for k, v in self.scorers.items()
                },
            )
            for i in range(len(batch_hyps.length))
        ]

    def batch_beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Batch-compute topk full token ids and partial token ids.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each token.
                Its shape is `(n_beam, self.vocab_size)`.
            ids (torch.Tensor): The partial token ids to compute topk.
                Its shape is `(n_beam, self.pre_beam_size)`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                The topk full (prev_hyp, new_token) ids
                and partial (prev_hyp, new_token) ids.
                Their shapes are all `(self.beam_size,)`

        """
        top_ids = weighted_scores.view(-1).topk(self.beam_size)[1]
        # Because of the flatten above, `top_ids` is organized as:
        # [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
        # where V is `self.n_vocab` and K is `self.beam_size`
        prev_hyp_ids = top_ids // self.n_vocab
        new_token_ids = top_ids % self.n_vocab
        return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids

    def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
        """Get an initial hypothesis data.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            Hypothesis: The initial hypothesis.

        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.batch_init_state(x)
            init_scores[k] = 0.0
        return self.batchfy(
            [
                Hypothesis(
                    score=0.0,
                    scores=init_scores,
                    states=init_states,
                    yseq=torch.tensor([self.sos], device=x.device),
                )
            ]
        )

    def score_full(
        self, hyp: BatchHypothesis, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x)
        return scores, states

    def score_partial(
        self, hyp: BatchHypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.part_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 2D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.part_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.batch_score_partial(
                hyp.yseq, ids, hyp.states[k], x
            )
        return scores, states

    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.

        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, v in part_states.items():
            new_states[k] = v
        return new_states

    def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis:
        """Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (BatchHypothesis): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)

        Returns:
            BatchHypothesis: Best sorted hypotheses

        """
        n_batch = len(running_hyps)
        part_ids = None  # no pre-beam
        # batch scoring
        weighted_scores = torch.zeros(
            n_batch, self.n_vocab, dtype=x.dtype, device=x.device
        )
        scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
        for k in self.full_scorers:
            weighted_scores += self.weights[k] * scores[k]
        # partial scoring
        if self.do_pre_beam:
            pre_beam_scores = (
                weighted_scores
                if self.pre_beam_score_key == "full"
                else scores[self.pre_beam_score_key]
            )
            part_ids = torch.topk(pre_beam_scores, self.pre_beam_size, dim=-1)[1]
        # NOTE(takaaki-hori): Unlike BeamSearch, we assume that score_partial returns
        # full-size score matrices, which have non-zero scores for part_ids and zeros
        # for others.
        part_scores, part_states = self.score_partial(running_hyps, part_ids, x)
        for k in self.part_scorers:
            weighted_scores += self.weights[k] * part_scores[k]
        # add previous hyp scores
        weighted_scores += running_hyps.score.to(
            dtype=x.dtype, device=x.device
        ).unsqueeze(1)

        # TODO(karita): do not use list. use batch instead
        # see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029
        # update hyps
        best_hyps = []
        prev_hyps = self.unbatchfy(running_hyps)
        for (
            full_prev_hyp_id,
            full_new_token_id,
            part_prev_hyp_id,
            part_new_token_id,
        ) in zip(*self.batch_beam(weighted_scores, part_ids)):
            prev_hyp = prev_hyps[full_prev_hyp_id]
            best_hyps.append(
                Hypothesis(
                    score=weighted_scores[full_prev_hyp_id, full_new_token_id],
                    yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
                    scores=self.merge_scores(
                        prev_hyp.scores,
                        {k: v[full_prev_hyp_id] for k, v in scores.items()},
                        full_new_token_id,
                        {k: v[part_prev_hyp_id] for k, v in part_scores.items()},
                        part_new_token_id,
                    ),
                    states=self.merge_states(
                        {
                            k: self.full_scorers[k].select_state(v, full_prev_hyp_id)
                            for k, v in states.items()
                        },
                        {
                            k: self.part_scorers[k].select_state(
                                v, part_prev_hyp_id, part_new_token_id
                            )
                            for k, v in part_states.items()
                        },
                        part_new_token_id,
                    ),
                )
            )
        return self.batchfy(best_hyps)

    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: BatchHypothesis,
        ended_hyps: List[Hypothesis],
    ) -> BatchHypothesis:
        """Perform post-processing of beam search iterations.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (float): The maximum length ratio in beam search.
            running_hyps (BatchHypothesis): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            BatchHypothesis: The new running hypotheses.

        """
        n_batch = running_hyps.yseq.shape[0]
        logging.debug(f"the number of running hypotheses: {n_batch}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: "
                + "".join(
                    [
                        self.token_list[x]
                        for x in running_hyps.yseq[0, 1 : running_hyps.length[0]]
                    ]
                )
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            yseq_eos = torch.cat(
                (
                    running_hyps.yseq,
                    torch.full(
                        (n_batch, 1),
                        self.eos,
                        device=running_hyps.yseq.device,
                        dtype=torch.int64,
                    ),
                ),
                1,
            )
            running_hyps.yseq.resize_as_(yseq_eos)
            running_hyps.yseq[:] = yseq_eos
            running_hyps.length[:] = yseq_eos.shape[1]

        # add ended hypotheses to a final list, and remove them from current hypotheses
        # (this can be a problem when the number of hyps < beam)
        is_eos = (
            running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1]
            == self.eos
        )
        for b in torch.nonzero(is_eos, as_tuple=False).view(-1):
            hyp = self._select(running_hyps, b)
            ended_hyps.append(hyp)
        remained_ids = torch.nonzero(is_eos == 0, as_tuple=False).view(-1)
        return self._batch_select(running_hyps, remained_ids)
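A standalone illustration (not part of the commit) of how `batch_beam` recovers hypothesis and token indices from the flattened top-k:

import torch

n_vocab, beam_size = 5, 2
weighted_scores = torch.tensor([[0.1, 0.9, 0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.8, 0.0, 0.0]])
top_ids = weighted_scores.view(-1).topk(beam_size)[1]  # flattened indices
prev_hyp_ids = top_ids // n_vocab   # which running hypothesis each pick came from
new_token_ids = top_ids % n_vocab   # which vocabulary token was picked
# top_ids = [1, 7] -> prev_hyp_ids = [0, 1], new_token_ids = [1, 2]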
270  funasr_local/modules/beam_search/batch_beam_search_online_sim.py    Normal file
@@ -0,0 +1,270 @@
|
||||
"""Parallel beam search module for online simulation."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import yaml
|
||||
|
||||
import torch
|
||||
|
||||
from funasr_local.modules.beam_search.batch_beam_search import BatchBeamSearch
|
||||
from funasr_local.modules.beam_search.beam_search import Hypothesis
|
||||
from funasr_local.models.e2e_asr_common import end_detect
|
||||
|
||||
|
||||
class BatchBeamSearchOnlineSim(BatchBeamSearch):
|
||||
"""Online beam search implementation.
|
||||
|
||||
This simulates streaming decoding.
|
||||
It requires encoded features of entire utterance and
|
||||
extracts block by block from it as it shoud be done
|
||||
in streaming processing.
|
||||
This is based on Tsunoo et al, "STREAMING TRANSFORMER ASR
|
||||
WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH"
|
||||
(https://arxiv.org/abs/2006.14941).
|
||||
"""
|
||||
|
||||
def set_streaming_config(self, asr_config: str):
|
||||
"""Set config file for streaming decoding.
|
||||
|
||||
Args:
|
||||
asr_config (str): The config file for asr training
|
||||
|
||||
"""
|
||||
train_config_file = Path(asr_config)
|
||||
self.block_size = None
|
||||
self.hop_size = None
|
||||
self.look_ahead = None
|
||||
config = None
|
||||
with train_config_file.open("r", encoding="utf-8") as f:
|
||||
args = yaml.safe_load(f)
|
||||
if "encoder_conf" in args.keys():
|
||||
if "block_size" in args["encoder_conf"].keys():
|
||||
self.block_size = args["encoder_conf"]["block_size"]
|
||||
if "hop_size" in args["encoder_conf"].keys():
|
||||
self.hop_size = args["encoder_conf"]["hop_size"]
|
||||
if "look_ahead" in args["encoder_conf"].keys():
|
||||
self.look_ahead = args["encoder_conf"]["look_ahead"]
|
||||
elif "config" in args.keys():
|
||||
config = args["config"]
|
||||
if config is None:
|
||||
logging.info(
|
||||
"Cannot find config file for streaming decoding: "
|
||||
+ "apply batch beam search instead."
|
||||
)
|
||||
return
|
||||
if (
|
||||
self.block_size is None or self.hop_size is None or self.look_ahead is None
|
||||
) and config is not None:
|
||||
config_file = Path(config)
|
||||
with config_file.open("r", encoding="utf-8") as f:
|
||||
args = yaml.safe_load(f)
|
||||
if "encoder_conf" in args.keys():
|
||||
enc_args = args["encoder_conf"]
|
||||
if enc_args and "block_size" in enc_args:
|
||||
self.block_size = enc_args["block_size"]
|
||||
if enc_args and "hop_size" in enc_args:
|
||||
self.hop_size = enc_args["hop_size"]
|
||||
if enc_args and "look_ahead" in enc_args:
|
||||
self.look_ahead = enc_args["look_ahead"]
|
||||
|
||||
def set_block_size(self, block_size: int):
|
||||
"""Set block size for streaming decoding.
|
||||
|
||||
Args:
|
||||
block_size (int): The block size of encoder
|
||||
"""
|
||||
self.block_size = block_size
|
||||
|
||||
def set_hop_size(self, hop_size: int):
|
||||
"""Set hop size for streaming decoding.
|
||||
|
||||
Args:
|
||||
hop_size (int): The hop size of encoder
|
||||
"""
|
||||
self.hop_size = hop_size
|
||||
|
||||
def set_look_ahead(self, look_ahead: int):
|
||||
"""Set look ahead size for streaming decoding.
|
||||
|
||||
Args:
|
||||
look_ahead (int): The look ahead size of encoder
|
||||
"""
|
||||
self.look_ahead = look_ahead
|
||||
|
||||
    def forward(
        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses an end-detect function
                to automatically find maximum hypothesis lengths
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results

        """
        self.conservative = True  # always true

        if self.block_size and self.hop_size and self.look_ahead:
            cur_end_frame = int(self.block_size - self.look_ahead)
        else:
            cur_end_frame = x.shape[0]
        process_idx = 0
        if cur_end_frame < x.shape[0]:
            h = x.narrow(0, 0, cur_end_frame)
        else:
            h = x

        # set length bounds
        if maxlenratio == 0:
            maxlen = x.shape[0]
        else:
            maxlen = max(1, int(maxlenratio * x.size(0)))
        minlen = int(minlenratio * x.size(0))
        logging.info("decoder input length: " + str(x.shape[0]))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # main loop of prefix search
        running_hyps = self.init_hyp(h)
        prev_hyps = []
        ended_hyps = []
        prev_repeat = False

        continue_decode = True

        while continue_decode:
            move_to_next_block = False
            if cur_end_frame < x.shape[0]:
                h = x.narrow(0, 0, cur_end_frame)
            else:
                h = x

            # extend states for ctc
            self.extend(h, running_hyps)

            while process_idx < maxlen:
                logging.debug("position " + str(process_idx))
                best = self.search(running_hyps, h)

                if process_idx == maxlen - 1:
                    # end decoding
                    running_hyps = self.post_process(
                        process_idx, maxlen, maxlenratio, best, ended_hyps
                    )
                n_batch = best.yseq.shape[0]
                local_ended_hyps = []
                is_local_eos = (
                    best.yseq[torch.arange(n_batch), best.length - 1] == self.eos
                )
                for i in range(is_local_eos.shape[0]):
                    if is_local_eos[i]:
                        hyp = self._select(best, i)
                        local_ended_hyps.append(hyp)
                    # NOTE(tsunoo): check repetitions here
                    # This is an implicit implementation of
                    # Eq (11) in https://arxiv.org/abs/2006.14941
                    # A flag prev_repeat is used instead of using set
                    elif (
                        not prev_repeat
                        and best.yseq[i, -1] in best.yseq[i, :-1]
                        and cur_end_frame < x.shape[0]
                    ):
                        move_to_next_block = True
                        prev_repeat = True
                if maxlenratio == 0.0 and end_detect(
                    [lh.asdict() for lh in local_ended_hyps], process_idx
                ):
                    logging.info(f"end detected at {process_idx}")
                    continue_decode = False
                    break
                if len(local_ended_hyps) > 0 and cur_end_frame < x.shape[0]:
                    move_to_next_block = True

                if move_to_next_block:
                    if (
                        self.hop_size
                        and cur_end_frame + int(self.hop_size) + int(self.look_ahead)
                        < x.shape[0]
                    ):
                        cur_end_frame += int(self.hop_size)
                    else:
                        cur_end_frame = x.shape[0]
                    logging.debug("Going to next block: %d", cur_end_frame)
                    if process_idx > 1 and len(prev_hyps) > 0 and self.conservative:
                        running_hyps = prev_hyps
                        process_idx -= 1
                        prev_hyps = []
                    break

                prev_repeat = False
                prev_hyps = running_hyps
                running_hyps = self.post_process(
                    process_idx, maxlen, maxlenratio, best, ended_hyps
                )

                if cur_end_frame >= x.shape[0]:
                    for hyp in local_ended_hyps:
                        ended_hyps.append(hyp)

                if len(running_hyps) == 0:
                    logging.info("no hypothesis. Finish decoding.")
                    continue_decode = False
                    break
                else:
                    logging.debug(f"remained hypotheses: {len(running_hyps)}")
                # increment number
                process_idx += 1

        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there are no N-best results, performing recognition "
                "again with a smaller minlenratio."
            )
            return (
                []
                if minlenratio < 0.1
                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
            )

        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
            )
        return nbest_hyps
    def extend(self, x: torch.Tensor, hyps: Hypothesis) -> List[Hypothesis]:
        """Extend probabilities and states with more encoded chunks.

        Args:
            x (torch.Tensor): The extended encoder output feature
            hyps (Hypothesis): Current hypotheses (extended in place)

        Returns:
            Hypothesis: The extended hypothesis

        """
        for k, d in self.scorers.items():
            if hasattr(d, "extend_prob"):
                d.extend_prob(x)
            if hasattr(d, "extend_state"):
                hyps.states[k] = d.extend_state(hyps.states[k])
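    # Illustrative sketch (an addition, not part of the original file): driving
    # the block-wise streaming search above. Only the three setters and
    # forward() come from this class; `searcher` stands for an assumed,
    # already-constructed instance and `enc_out` for an encoder output tensor
    # of shape (T, D):
    #
    #     searcher.set_block_size(40)   # frames per encoder block
    #     searcher.set_hop_size(16)     # frames to advance per block
    #     searcher.set_look_ahead(16)   # future frames withheld from decoding
    #     nbest = searcher.forward(enc_out)  # List[Hypothesis], best first
    #
    # With all three sizes set, forward() decodes block by block and only
    # commits to the full input once cur_end_frame reaches enc_out.shape[0].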
1400
funasr_local/modules/beam_search/beam_search.py
Normal file
File diff suppressed because it is too large
704
funasr_local/modules/beam_search/beam_search_transducer.py
Normal file
@@ -0,0 +1,704 @@
"""Search algorithms for Transducer models."""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from funasr_local.models.joint_net.joint_network import JointNetwork


@dataclass
class Hypothesis:
    """Default hypothesis definition for Transducer search algorithms.

    Args:
        score: Total log-probability.
        yseq: Label sequence as integer ID sequence.
        dec_state: RNNDecoder or StatelessDecoder state.
            ((N, 1, D_dec), (N, 1, D_dec) or None) or None
        lm_state: RNNLM state. ((N, D_lm), (N, D_lm)) or None

    """

    score: float
    yseq: List[int]
    dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None
    lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None


@dataclass
class ExtendedHypothesis(Hypothesis):
    """Extended hypothesis definition for NSC beam search and mAES.

    Args:
        : Hypothesis dataclass arguments.
        dec_out: Decoder output sequence. (B, D_dec)
        lm_score: Log-probabilities of the LM for given label. (vocab_size)

    """

    dec_out: torch.Tensor = None
    lm_score: torch.Tensor = None
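# Illustrative sketch (an addition, not part of the original file): every
# search below starts from a single blank hypothesis whose label sequence
# holds only the blank/start ID 0; scores then accumulate in the log domain
# as labels are appended to yseq.
_example_init_hyp = Hypothesis(score=0.0, yseq=[0], dec_state=None)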
class BeamSearchTransducer:
    """Beam search implementation for Transducer.

    Args:
        decoder: Decoder module.
        joint_network: Joint network module.
        beam_size: Size of the beam.
        lm: LM class.
        lm_weight: LM weight for soft fusion.
        search_type: Search algorithm to use during inference.
        max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
        u_max: Maximum expected target sequence length. (ALSD)
        nstep: Number of maximum expansion steps at each time step. (mAES)
        expansion_gamma: Allowed logp difference for prune-by-value method. (mAES)
        expansion_beta:
            Number of additional candidates for expanded hypotheses selection. (mAES)
        score_norm: Normalize final scores by length.
        nbest: Number of final hypotheses.
        streaming: Whether to perform chunk-by-chunk beam search.

    """

    def __init__(
        self,
        decoder,
        joint_network: JointNetwork,
        beam_size: int,
        lm: Optional[torch.nn.Module] = None,
        lm_weight: float = 0.1,
        search_type: str = "default",
        max_sym_exp: int = 3,
        u_max: int = 50,
        nstep: int = 2,
        expansion_gamma: float = 2.3,
        expansion_beta: int = 2,
        score_norm: bool = False,
        nbest: int = 1,
        streaming: bool = False,
    ) -> None:
        """Construct a BeamSearchTransducer object."""
        super().__init__()

        self.decoder = decoder
        self.joint_network = joint_network

        self.vocab_size = decoder.vocab_size

        assert beam_size <= self.vocab_size, (
            "beam_size (%d) should be smaller than or equal to vocabulary size (%d)."
            % (
                beam_size,
                self.vocab_size,
            )
        )
        self.beam_size = beam_size

        if search_type == "default":
            self.search_algorithm = self.default_beam_search
        elif search_type == "tsd":
            assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (
                max_sym_exp
            )
            self.max_sym_exp = max_sym_exp

            self.search_algorithm = self.time_sync_decoding
        elif search_type == "alsd":
            assert not streaming, "ALSD is not available in streaming mode."

            assert u_max >= 0, "u_max should be a positive integer, a portion of max_T."
            self.u_max = u_max

            self.search_algorithm = self.align_length_sync_decoding
        elif search_type == "maes":
            assert self.vocab_size >= beam_size + expansion_beta, (
                "beam_size (%d) + expansion_beta (%d)"
                " should be smaller than or equal to vocab size (%d)."
                % (beam_size, expansion_beta, self.vocab_size)
            )
            self.max_candidates = beam_size + expansion_beta

            self.nstep = nstep
            self.expansion_gamma = expansion_gamma

            self.search_algorithm = self.modified_adaptive_expansion_search
        else:
            raise NotImplementedError(
                "Specified search type (%s) is not supported." % search_type
            )

        self.use_lm = lm is not None

        if self.use_lm:
            assert hasattr(lm, "rnn_type"), "Transformer LM is currently not supported."

            self.sos = self.vocab_size - 1

            self.lm = lm
            self.lm_weight = lm_weight

        self.score_norm = score_norm
        self.nbest = nbest

        self.reset_inference_cache()
    def __call__(
        self,
        enc_out: torch.Tensor,
        is_final: bool = True,
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            enc_out: Encoder output sequence. (T, D_enc)
            is_final: Whether enc_out is the final chunk of data.

        Returns:
            nbest_hyps: N-best decoding results

        """
        self.decoder.set_device(enc_out.device)

        hyps = self.search_algorithm(enc_out)

        if is_final:
            self.reset_inference_cache()

            return self.sort_nbest(hyps)

        self.search_cache = hyps

        return hyps

    def reset_inference_cache(self) -> None:
        """Reset cache for decoder scoring and streaming."""
        self.decoder.score_cache = {}
        self.search_cache = None
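    # Illustrative sketch (an addition, not part of the original file): for
    # chunk-by-chunk decoding, call the object on each encoder chunk with
    # is_final=False so hypotheses are cached in search_cache, then pass
    # is_final=True on the last chunk to get sorted n-best results. `chunks`
    # is an assumed list of (T_i, D_enc) tensors and `beam_search` an assumed
    # constructed BeamSearchTransducer instance:
    #
    #     for i, chunk in enumerate(chunks):
    #         hyps = beam_search(chunk, is_final=(i == len(chunks) - 1))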
    def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
        """Sort in-place hypotheses by score or score given sequence length.

        Args:
            hyps: Hypotheses.

        Return:
            hyps: Sorted hypotheses.

        """
        if self.score_norm:
            hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
        else:
            hyps.sort(key=lambda x: x.score, reverse=True)

        return hyps[: self.nbest]

    def recombine_hyps(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
        """Recombine hypotheses with same label ID sequence.

        Args:
            hyps: Hypotheses.

        Returns:
            final: Recombined hypotheses.

        """
        final = {}

        for hyp in hyps:
            str_yseq = "_".join(map(str, hyp.yseq))

            if str_yseq in final:
                final[str_yseq].score = np.logaddexp(final[str_yseq].score, hyp.score)
            else:
                final[str_yseq] = hyp

        return [*final.values()]
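    # Illustrative sketch (an addition, not part of the original file): two
    # hypotheses with the same yseq represent the same string reached along
    # different alignments, so their probabilities add. In the log domain that
    # sum is np.logaddexp, e.g. combining scores -1.0 and -2.0:
    #
    #     np.logaddexp(-1.0, -2.0)  # ~ -0.69, i.e. log(e**-1 + e**-2)
    #
    # which is larger (more probable) than either input score alone.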
    def select_k_expansions(
        self,
        hyps: List[ExtendedHypothesis],
        topk_idx: torch.Tensor,
        topk_logp: torch.Tensor,
    ) -> List[ExtendedHypothesis]:
        """Return K hypothesis candidates for expansion from a list of hypotheses.

        K candidates are selected according to the extended hypotheses
        probabilities and a prune-by-value method, where K is equal to
        beam_size + beta.

        Args:
            hyps: Hypotheses.
            topk_idx: Indices of candidate hypotheses.
            topk_logp: Log-probabilities of candidate hypotheses.

        Returns:
            k_expansions: Best K expansion hypotheses candidates.

        """
        k_expansions = []

        for i, hyp in enumerate(hyps):
            hyp_i = [
                (int(k), hyp.score + float(v))
                for k, v in zip(topk_idx[i], topk_logp[i])
            ]
            k_best_exp = max(hyp_i, key=lambda x: x[1])[1]

            k_expansions.append(
                sorted(
                    filter(
                        lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i
                    ),
                    key=lambda x: x[1],
                    reverse=True,
                )
            )

        return k_expansions
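    # Illustrative sketch (an addition, not part of the original file): the
    # prune-by-value filter above keeps only expansions whose score is within
    # expansion_gamma of the best one. With expansion_gamma = 2.3 and candidate
    # scores [-1.0, -2.5, -4.0], the threshold is -1.0 - 2.3 = -3.3, so -4.0
    # is dropped and the kept expansions are [-1.0, -2.5], sorted best first.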
    def create_lm_batch_inputs(self, hyps_seq: List[List[int]]) -> torch.Tensor:
        """Make batch of inputs with left padding for LM scoring.

        Args:
            hyps_seq: Hypothesis sequences.

        Returns:
            : Padded batch of sequences.

        """
        max_len = max([len(h) for h in hyps_seq])

        # NOTE: the legacy torch.LongTensor constructor does not accept a
        # device keyword; build with torch.tensor and an explicit dtype/device.
        return torch.tensor(
            [[self.sos] + ([0] * (max_len - len(h))) + h[1:] for h in hyps_seq],
            dtype=torch.long,
            device=self.decoder.device,
        )
    def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Beam search implementation without prefix search.

        Modified from https://arxiv.org/pdf/1211.3711.pdf

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        beam_k = min(self.beam_size, (self.vocab_size - 1))
        max_t = len(enc_out)

        if self.search_cache is not None:
            kept_hyps = self.search_cache
        else:
            kept_hyps = [
                Hypothesis(
                    score=0.0,
                    yseq=[0],
                    dec_state=self.decoder.init_state(1),
                )
            ]

        for t in range(max_t):
            hyps = kept_hyps
            kept_hyps = []

            while True:
                max_hyp = max(hyps, key=lambda x: x.score)
                hyps.remove(max_hyp)

                label = torch.full(
                    (1, 1),
                    max_hyp.yseq[-1],
                    dtype=torch.long,
                    device=self.decoder.device,
                )
                dec_out, state = self.decoder.score(
                    label,
                    max_hyp.yseq,
                    max_hyp.dec_state,
                )

                logp = torch.log_softmax(
                    self.joint_network(enc_out[t : t + 1, :], dec_out),
                    dim=-1,
                ).squeeze(0)
                top_k = logp[1:].topk(beam_k, dim=-1)

                # blank transition (index 0): keep the label sequence as-is
                kept_hyps.append(
                    Hypothesis(
                        score=(max_hyp.score + float(logp[0:1])),
                        yseq=max_hyp.yseq,
                        dec_state=max_hyp.dec_state,
                        lm_state=max_hyp.lm_state,
                    )
                )

                if self.use_lm:
                    # NOTE: torch.LongTensor does not take a device keyword;
                    # use torch.tensor with an explicit dtype and device.
                    lm_scores, lm_state = self.lm.score(
                        torch.tensor(
                            [self.sos] + max_hyp.yseq[1:],
                            dtype=torch.long,
                            device=self.decoder.device,
                        ),
                        max_hyp.lm_state,
                        None,
                    )
                else:
                    lm_state = max_hyp.lm_state

                for logp, k in zip(*top_k):
                    score = max_hyp.score + float(logp)

                    if self.use_lm:
                        score += self.lm_weight * lm_scores[k + 1]

                    hyps.append(
                        Hypothesis(
                            score=score,
                            yseq=max_hyp.yseq + [int(k + 1)],
                            dec_state=state,
                            lm_state=lm_state,
                        )
                    )

                hyps_max = float(max(hyps, key=lambda x: x.score).score)
                kept_most_prob = sorted(
                    [hyp for hyp in kept_hyps if hyp.score > hyps_max],
                    key=lambda x: x.score,
                )
                if len(kept_most_prob) >= self.beam_size:
                    kept_hyps = kept_most_prob
                    break

        return kept_hyps
    def align_length_sync_decoding(
        self,
        enc_out: torch.Tensor,
    ) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        t_max = int(enc_out.size(0))
        u_max = min(self.u_max, (t_max - 1))

        B = [Hypothesis(yseq=[0], score=0.0, dec_state=self.decoder.init_state(1))]
        final = []

        if self.use_lm:
            B[0].lm_state = self.lm.zero_state()

        for i in range(t_max + u_max):
            A = []

            B_ = []
            B_enc_out = []
            for hyp in B:
                u = len(hyp.yseq) - 1
                t = i - u

                if t > (t_max - 1):
                    continue

                B_.append(hyp)
                B_enc_out.append((t, enc_out[t]))

            if B_:
                beam_enc_out = torch.stack([b[1] for b in B_enc_out])
                beam_dec_out, beam_state = self.decoder.batch_score(B_)

                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)

                if self.use_lm:
                    beam_lm_scores, beam_lm_states = self.lm.batch_score(
                        self.create_lm_batch_inputs([b.yseq for b in B_]),
                        [b.lm_state for b in B_],
                        None,
                    )

                for i, hyp in enumerate(B_):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[i, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )

                    A.append(new_hyp)

                    if B_enc_out[i][0] == (t_max - 1):
                        final.append(new_hyp)

                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, i),
                            lm_state=hyp.lm_state,
                        )

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                            new_hyp.lm_state = beam_lm_states[i]

                        A.append(new_hyp)

                B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]
                B = self.recombine_hyps(B)

        if final:
            return final

        return B
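    # Illustrative sketch (an addition, not part of the original file): ALSD
    # sweeps anti-diagonals of the (time t, label count u) lattice, i.e. at
    # step i it visits every hypothesis with t + u == i. For t_max=3, u_max=2,
    # step i=2 covers (t=2, u=0), (t=1, u=1) and (t=0, u=2), which is why the
    # loop above recovers t as i - (len(hyp.yseq) - 1) for each hypothesis.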
    def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Time synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        if self.search_cache is not None:
            B = self.search_cache
        else:
            B = [
                Hypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.init_state(1),
                )
            ]

            if self.use_lm:
                B[0].lm_state = self.lm.zero_state()

        for enc_out_t in enc_out:
            A = []
            C = B

            enc_out_t = enc_out_t.unsqueeze(0)

            for v in range(self.max_sym_exp):
                D = []

                beam_dec_out, beam_state = self.decoder.batch_score(C)

                beam_logp = torch.log_softmax(
                    self.joint_network(enc_out_t, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)

                seq_A = [h.yseq for h in A]

                for i, hyp in enumerate(C):
                    if hyp.yseq not in seq_A:
                        A.append(
                            Hypothesis(
                                score=(hyp.score + float(beam_logp[i, 0])),
                                yseq=hyp.yseq[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                            )
                        )
                    else:
                        dict_pos = seq_A.index(hyp.yseq)

                        A[dict_pos].score = np.logaddexp(
                            A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
                        )

                if v < (self.max_sym_exp - 1):
                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            self.create_lm_batch_inputs([c.yseq for c in C]),
                            [c.lm_state for c in C],
                            None,
                        )

                    for i, hyp in enumerate(C):
                        for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                yseq=(hyp.yseq + [int(k)]),
                                dec_state=self.decoder.select_state(beam_state, i),
                                lm_state=hyp.lm_state,
                            )

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                                new_hyp.lm_state = beam_lm_states[i]

                            D.append(new_hyp)

                    C = sorted(D, key=lambda x: x.score, reverse=True)[: self.beam_size]

            B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]

        return B
    def modified_adaptive_expansion_search(
        self,
        enc_out: torch.Tensor,
    ) -> List[ExtendedHypothesis]:
        """Modified version of Adaptive Expansion Search (mAES).

        Based on AES (https://ieeexplore.ieee.org/document/9250505) and
        NSC (https://arxiv.org/abs/2201.05420).

        Args:
            enc_out: Encoder output sequence. (T, D_enc)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        if self.search_cache is not None:
            kept_hyps = self.search_cache
        else:
            init_tokens = [
                ExtendedHypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.init_state(1),
                )
            ]

            beam_dec_out, beam_state = self.decoder.batch_score(
                init_tokens,
            )

            if self.use_lm:
                beam_lm_scores, beam_lm_states = self.lm.batch_score(
                    self.create_lm_batch_inputs([h.yseq for h in init_tokens]),
                    [h.lm_state for h in init_tokens],
                    None,
                )

                lm_state = beam_lm_states[0]
                lm_score = beam_lm_scores[0]
            else:
                lm_state = None
                lm_score = None

            kept_hyps = [
                ExtendedHypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.select_state(beam_state, 0),
                    dec_out=beam_dec_out[0],
                    lm_state=lm_state,
                    lm_score=lm_score,
                )
            ]

        for enc_out_t in enc_out:
            hyps = kept_hyps
            kept_hyps = []

            beam_enc_out = enc_out_t.unsqueeze(0)

            list_b = []
            for n in range(self.nstep):
                beam_dec_out = torch.stack([h.dec_out for h in hyps])

                beam_logp, beam_idx = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                ).topk(self.max_candidates, dim=-1)

                k_expansions = self.select_k_expansions(hyps, beam_idx, beam_logp)

                list_exp = []
                for i, hyp in enumerate(hyps):
                    for k, new_score in k_expansions[i]:
                        new_hyp = ExtendedHypothesis(
                            yseq=hyp.yseq[:],
                            score=new_score,
                            dec_out=hyp.dec_out,
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_score=hyp.lm_score,
                        )

                        if k == 0:
                            list_b.append(new_hyp)
                        else:
                            new_hyp.yseq.append(int(k))

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * float(hyp.lm_score[k])

                            list_exp.append(new_hyp)

                if not list_exp:
                    kept_hyps = sorted(
                        self.recombine_hyps(list_b), key=lambda x: x.score, reverse=True
                    )[: self.beam_size]

                    break
                else:
                    beam_dec_out, beam_state = self.decoder.batch_score(
                        list_exp,
                    )

                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            self.create_lm_batch_inputs([h.yseq for h in list_exp]),
                            [h.lm_state for h in list_exp],
                            None,
                        )

                    if n < (self.nstep - 1):
                        for i, hyp in enumerate(list_exp):
                            hyp.dec_out = beam_dec_out[i]
                            hyp.dec_state = self.decoder.select_state(beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = beam_lm_states[i]
                                hyp.lm_score = beam_lm_scores[i]

                        hyps = list_exp[:]
                    else:
                        beam_logp = torch.log_softmax(
                            self.joint_network(beam_enc_out, beam_dec_out),
                            dim=-1,
                        )

                        for i, hyp in enumerate(list_exp):
                            hyp.score += float(beam_logp[i, 0])

                            hyp.dec_out = beam_dec_out[i]
                            hyp.dec_state = self.decoder.select_state(beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = beam_lm_states[i]
                                hyp.lm_score = beam_lm_scores[i]

                        kept_hyps = sorted(
                            self.recombine_hyps(list_b + list_exp),
                            key=lambda x: x.score,
                            reverse=True,
                        )[: self.beam_size]

        return kept_hyps
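# Illustrative sketch (an addition, not part of the original file): wiring the
# searcher together. `decoder`, `joint_net` and `enc_out` stand for assumed,
# already-built objects; only the keyword names below come from __init__.
#
#     searcher = BeamSearchTransducer(
#         decoder=decoder,
#         joint_network=joint_net,
#         beam_size=5,
#         search_type="maes",      # or "default", "tsd", "alsd"
#         nstep=2,
#         expansion_gamma=2.3,
#         expansion_beta=2,
#         nbest=1,
#     )
#     nbest = searcher(enc_out)    # enc_out: (T, D_enc) tensor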
0
funasr_local/modules/data2vec/__init__.py
Normal file
147
funasr_local/modules/data2vec/data_utils.py
Normal file
@@ -0,0 +1,147 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional, Tuple

import numpy as np
import torch


def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks.
            Should be of size 2, where the first element is the batch size
            and the second is the number of timesteps.
        padding_mask: optional padding mask of the same size as shape, which
            will prevent masking padded elements.
        mask_prob: probability for each token to be chosen as the start of a
            span to be masked. This will be multiplied by the number of
            timesteps divided by the length of the mask span to mask
            approximately this percentage of all elements. However, due to
            overlaps, the actual number will be smaller (unless no_overlap
            is True).
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and
                stdev mask_other. mask is min 1 element
            poisson = sample from Poisson distribution with lambda = mask_length
        min_masks: minimum number of masked spans
        no_overlap: if true, will use an alternative recursive algorithm that
            prevents spans from overlapping
        min_space: only used if no_overlap is True; this is how many elements
            to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until the
            same number of masks remains in each sample
        mask_dropout: randomly drop out this percentage of masks in each example
    """
    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    all_num_mask = int(
        # add a random number for probabilistic rounding
        mask_prob * all_sz / float(mask_length)
        + np.random.rand()
    )

    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                # NOTE: np.int was removed in NumPy >= 1.24; the builtin int
                # keeps the original behavior.
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / np.sum(lens)
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                min_len = sz - num_mask - 1

            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len and require_same_masks:
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        if mask_dropout > 0:
            num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
            mask_idc = np.random.choice(
                mask_idc, len(mask_idc) - num_holes, replace=False
            )

        mask[i, mask_idc] = True

    return mask
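# Illustrative sketch (an addition, not part of the original file): wav2vec-style
# span masking over a batch of 2 sequences of 100 timesteps. The parameter
# values below are illustrative, not prescribed by this module.
def _demo_compute_mask_indices():
    mask = compute_mask_indices(
        shape=(2, 100),
        padding_mask=None,
        mask_prob=0.65,
        mask_length=10,
    )
    # mask is a (2, 100) boolean ndarray; True marks masked timesteps, and
    # require_same_masks (default) equalizes the mask count across rows.
    return mask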
132
funasr_local/modules/data2vec/ema_module.py
Normal file
@@ -0,0 +1,132 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Used for EMA tracking a given pytorch module. The user is responsible for
calling step() and setting the appropriate decay.
"""

import copy
import logging

import torch


class EMAModule:
    """Exponential Moving Average of Fairseq Models"""

    def __init__(self, model, ema_decay=0.9999, ema_fp32=False, device=None, skip_keys=None):
        """
        @param model model to initialize the EMA with
        @param ema_decay decay factor for the moving average
        @param ema_fp32 if True, keep a separate fp32 copy of the EMA params
        @param device If provided, copy EMA to this device (e.g. gpu).
            Otherwise EMA is in the same device as the model.
        @param skip_keys state dict keys to copy verbatim instead of averaging
        """

        self.decay = ema_decay
        self.ema_fp32 = ema_fp32
        self.model = copy.deepcopy(model)
        self.model.requires_grad_(False)
        self.skip_keys = skip_keys or set()
        self.fp32_params = {}

        if device is not None:
            logging.info(f"Copying EMA model to device {device}")
            self.model = self.model.to(device=device)

        if self.ema_fp32:
            self.build_fp32_params()

        self.update_freq_counter = 0
    def build_fp32_params(self, state_dict=None):
        """
        Store a copy of the EMA params in fp32.
        If a state dict is passed, the EMA params are copied from
        the provided state dict. Otherwise, they are copied from the
        current EMA model parameters.
        """
        if not self.ema_fp32:
            raise RuntimeError(
                "build_fp32_params should not be called if ema_fp32=False. "
                "Use ema_fp32=True if this is really intended."
            )

        if state_dict is None:
            state_dict = self.model.state_dict()

        def _to_float(t):
            return t.float() if torch.is_floating_point(t) else t

        for param_key in state_dict:
            if param_key in self.fp32_params:
                self.fp32_params[param_key].copy_(state_dict[param_key])
            else:
                self.fp32_params[param_key] = _to_float(state_dict[param_key])

    def restore(self, state_dict, build_fp32_params=False):
        """Load data from a model spec into EMA model"""
        self.model.load_state_dict(state_dict, strict=False)
        if build_fp32_params:
            self.build_fp32_params(state_dict)

    def set_decay(self, decay):
        self.decay = decay

    def get_decay(self):
        return self.decay
    def _step_internal(self, new_model):
        """One update of the EMA model based on new model weights"""
        decay = self.decay

        ema_state_dict = {}
        ema_params = (
            self.fp32_params if self.ema_fp32 else self.model.state_dict()
        )
        for key, param in new_model.state_dict().items():
            if isinstance(param, dict):
                continue
            try:
                ema_param = ema_params[key]
            except KeyError:
                ema_param = (
                    param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                )

            if param.shape != ema_param.shape:
                raise ValueError(
                    "incompatible tensor shapes between model param and ema param: "
                    "{} vs. {}".format(param.shape, ema_param.shape)
                )

            if "version" in key:
                # Do not decay a model.version pytorch param
                continue

            if key in self.skip_keys or (
                "num_batches_tracked" in key and ema_param.dtype == torch.int64
            ):
                ema_param = param.to(dtype=ema_param.dtype).clone()
                ema_params[key].copy_(ema_param)
            else:
                ema_param.mul_(decay)
                ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
            ema_state_dict[key] = ema_param
        self.restore(ema_state_dict, build_fp32_params=False)

    def step(self, new_model):
        self._step_internal(new_model)

    def reverse(self, model):
        """
        Load the model parameters from EMA model.
        Useful for inference or fine-tuning from the EMA model.
        """
        d = self.model.state_dict()
        if "_ema" in d:
            del d["_ema"]

        model.load_state_dict(d, strict=False)
        return model
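    # Illustrative sketch (an addition, not part of the original file): each
    # step applies ema = decay * ema + (1 - decay) * param. With decay 0.9999,
    # an EMA value of 1.0 and a new param of 0.0 becomes 0.9999 after one step,
    # so the average follows the model with a time constant of roughly
    # 1 / (1 - decay) = 10000 updates. Typical training-loop usage (names
    # other than EMAModule, step and reverse are assumed):
    #
    #     ema = EMAModule(model, ema_decay=0.9999)
    #     for batch in loader:
    #         train_step(model, batch)
    #         ema.step(model)          # update the shadow weights
    #     model = ema.reverse(model)   # copy EMA weights back for inference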
18
funasr_local/modules/data2vec/grad_multiply.py
Normal file
@@ -0,0 +1,18 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch


class GradMultiply(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        # identity in the forward pass; only remember the scale
        ctx.scale = scale
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        # scale the incoming gradient; scale itself receives no gradient
        return grad * ctx.scale, None
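# Illustrative sketch (an addition, not part of the original file): GradMultiply
# leaves activations untouched but rescales the gradient flowing into whatever
# produced x, e.g. to slow down a shared encoder relative to a task head.
def _demo_grad_multiply():
    x = torch.ones(3, requires_grad=True)
    y = GradMultiply.apply(x, 0.1)  # forward: y equals x
    y.sum().backward()
    return x.grad  # tensor([0.1, 0.1, 0.1]) instead of ones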
671
funasr_local/modules/data2vec/multihead_attention.py
Normal file
@@ -0,0 +1,671 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter

from funasr_local.modules.data2vec.quant_noise import quant_noise


class FairseqDropout(nn.Module):
    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        if self.p > 0 and (self.training or self.apply_during_inference):
            return F.dropout(x, p=self.p, training=True, inplace=inplace)
        else:
            return x
    def make_generation_fast_(
        self,
        name: str,
        retain_dropout: bool = False,
        retain_dropout_modules: Optional[List[str]] = None,
        **kwargs
    ):
        if retain_dropout:
            if retain_dropout_modules is not None and self.module_name is None:
                logging.warning(
                    "Cannot enable dropout during inference for module {} "
                    "because module_name was not set".format(name)
                )
            elif (
                retain_dropout_modules is None  # if None, apply to all modules
                or self.module_name in retain_dropout_modules
            ):
                logging.info(
                    "Enabling dropout during inference for module: {}".format(name)
                )
                self.apply_during_inference = True
            else:
                logging.info("Disabling dropout for module: {}".format(name))
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )

        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and value to be of the same size"
        )

        self.k_proj = quant_noise(
            nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.onnx_trace = False
        self.skip_embed_dim_check = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True
    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def _get_reserve_head_index(self, num_heads_to_keep: int):
        k_proj_heads_norm = []
        q_proj_heads_norm = []
        v_proj_heads_norm = []

        for i in range(self.num_heads):
            start_idx = i * self.head_dim
            end_idx = (i + 1) * self.head_dim
            k_proj_heads_norm.append(
                torch.sum(
                    torch.abs(
                        self.k_proj.weight[
                            start_idx:end_idx,
                        ]
                    )
                ).tolist()
                + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist()
            )
            q_proj_heads_norm.append(
                torch.sum(
                    torch.abs(
                        self.q_proj.weight[
                            start_idx:end_idx,
                        ]
                    )
                ).tolist()
                + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist()
            )
            v_proj_heads_norm.append(
                torch.sum(
                    torch.abs(
                        self.v_proj.weight[
                            start_idx:end_idx,
                        ]
                    )
                ).tolist()
                + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist()
            )

        heads_norm = []
        for i in range(self.num_heads):
            heads_norm.append(
                k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i]
            )

        sorted_head_index = sorted(
            range(self.num_heads), key=lambda k: heads_norm[k], reverse=True
        )
        reserve_head_index = []
        for i in range(num_heads_to_keep):
            start = sorted_head_index[i] * self.head_dim
            end = (sorted_head_index[i] + 1) * self.head_dim
            reserve_head_index.append((start, end))
        return reserve_head_index
    def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]):
        new_q_weight = []
        new_q_bias = []
        new_k_weight = []
        new_k_bias = []
        new_v_weight = []
        new_v_bias = []
        new_out_proj_weight = []

        for ele in reserve_head_index:
            start_idx, end_idx = ele
            new_q_weight.append(
                self.q_proj.weight[
                    start_idx:end_idx,
                ]
            )
            new_q_bias.append(self.q_proj.bias[start_idx:end_idx])

            new_k_weight.append(
                self.k_proj.weight[
                    start_idx:end_idx,
                ]
            )

            new_k_bias.append(self.k_proj.bias[start_idx:end_idx])

            new_v_weight.append(
                self.v_proj.weight[
                    start_idx:end_idx,
                ]
            )
            new_v_bias.append(self.v_proj.bias[start_idx:end_idx])

            new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])

        new_q_weight = torch.cat(new_q_weight).detach()
        new_k_weight = torch.cat(new_k_weight).detach()
        new_v_weight = torch.cat(new_v_weight).detach()
        new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach()
        new_q_weight.requires_grad = True
        new_k_weight.requires_grad = True
        new_v_weight.requires_grad = True
        new_out_proj_weight.requires_grad = True

        new_q_bias = torch.cat(new_q_bias).detach()
        new_q_bias.requires_grad = True

        new_k_bias = torch.cat(new_k_bias).detach()
        new_k_bias.requires_grad = True

        new_v_bias = torch.cat(new_v_bias).detach()
        new_v_bias.requires_grad = True

        self.q_proj.weight = torch.nn.Parameter(new_q_weight)
        self.q_proj.bias = torch.nn.Parameter(new_q_bias)

        self.k_proj.weight = torch.nn.Parameter(new_k_weight)
        self.k_proj.bias = torch.nn.Parameter(new_k_bias)

        self.v_proj.weight = torch.nn.Parameter(new_v_weight)
        self.v_proj.bias = torch.nn.Parameter(new_v_bias)

        self.out_proj.weight = torch.nn.Parameter(new_out_proj_weight)

        self.num_heads = len(reserve_head_index)
        self.embed_dim = self.head_dim * self.num_heads
        self.q_proj.out_features = self.embed_dim
        self.k_proj.out_features = self.embed_dim
        self.v_proj.out_features = self.embed_dim

    def _set_skip_embed_dim_check(self):
        self.skip_embed_dim_check = True
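    # Illustrative sketch (an addition, not part of the original file): the two
    # methods above implement L1-norm head pruning. To keep the 4 strongest of,
    # say, 8 heads of an assumed MultiheadAttention instance `mha`:
    #
    #     reserve = mha._get_reserve_head_index(num_heads_to_keep=4)
    #     mha._adaptive_prune_heads(reserve)
    #     mha._set_skip_embed_dim_check()  # projections are now narrower
    #
    # After pruning, embed_dim shrinks to head_dim * 4, so forward() bypasses
    # the PyTorch fast path and uses the explicit implementation below.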
    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        if not self.skip_embed_dim_check:
            assert (
                embed_dim == self.embed_dim
            ), f"query dim {embed_dim} != {self.embed_dim}"
            assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                assert (src_len, bsz) == value.shape[:2]

        if (
            not self.onnx_trace
            and not is_tpu  # don't use PyTorch version on TPUs
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
            # The Multihead attention implemented in pytorch forces strong dimension
            # checks for the input embedding dimension and K,Q,V projection dimensions.
            # Since pruning will break the dimension check and it is not easy to modify
            # the pytorch API, it is preferred to bypass the pytorch MHA when we need
            # to skip the embed_dim check.
            and not self.skip_embed_dim_check
        ):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training or self.dropout_module.apply_during_inference,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights
    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - prev_key_padding_mask.size(1)),
                    device=prev_key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [prev_key_padding_mask.float(), filler.float()], dim=1
                )
            else:
                new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - key_padding_mask.size(1)),
                    device=key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [filler.float(), key_padding_mask.float()], dim=1
                )
            else:
                new_key_padding_mask = key_padding_mask.float()
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask
    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if self.encoder_decoder_attention and input_buffer_k.size(
                        0
                    ) == new_order.size(0):
                        break
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        return attn_weights
def upgrade_state_dict_named(self, state_dict, name):
|
||||
prefix = name + "." if name != "" else ""
|
||||
items_to_add = {}
|
||||
keys_to_remove = []
|
||||
for k in state_dict.keys():
|
||||
if k.endswith(prefix + "in_proj_weight"):
|
||||
# in_proj_weight used to be q + k + v with same dimensions
|
||||
dim = int(state_dict[k].shape[0] / 3)
|
||||
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
|
||||
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim: 2 * dim]
|
||||
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim:]
|
||||
|
||||
keys_to_remove.append(k)
|
||||
|
||||
k_bias = prefix + "in_proj_bias"
|
||||
if k_bias in state_dict.keys():
|
||||
dim = int(state_dict[k].shape[0] / 3)
|
||||
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
|
||||
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
|
||||
dim: 2 * dim
|
||||
]
|
||||
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim:]
|
||||
|
||||
keys_to_remove.append(prefix + "in_proj_bias")
|
||||
|
||||
for k in keys_to_remove:
|
||||
del state_dict[k]
|
||||
|
||||
for key, value in items_to_add.items():
|
||||
state_dict[key] = value
|
||||
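A quick illustration of what upgrade_state_dict_named does to a legacy checkpoint (a sketch; the key name and tensors here are hypothetical, not part of this diff): the fused in_proj_weight of shape (3*dim, dim) is sliced into three separate (dim, dim) projection matrices.

import torch

# hypothetical legacy checkpoint entry with dim = 4, fused shape (12, 4)
state_dict = {"attn.in_proj_weight": torch.randn(12, 4)}
dim = state_dict["attn.in_proj_weight"].shape[0] // 3
q_w = state_dict["attn.in_proj_weight"][:dim]          # becomes attn.q_proj.weight
k_w = state_dict["attn.in_proj_weight"][dim:2 * dim]   # becomes attn.k_proj.weight
v_w = state_dict["attn.in_proj_weight"][2 * dim:]      # becomes attn.v_proj.weight
assert q_w.shape == k_w.shape == v_w.shape == (4, 4)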
107
funasr_local/modules/data2vec/quant_noise.py
Normal file
@@ -0,0 +1,107 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    # 2D matrix
    if not is_conv:
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            assert (
                module.in_channels % block_size == 0
            ), "Input channels must be a multiple of block sizes"
        # regular convolutions
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # no noise for evaluation
        if mod.training:
            if not is_conv:
                # gather weight and sizes
                weight = mod.weight
                in_features = weight.size(1)
                out_features = weight.size(0)

                # split weight matrix into blocks and randomly drop selected blocks
                mask = torch.zeros(
                    in_features // block_size * out_features, device=weight.device
                )
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)

            else:
                # gather weight and sizes
                weight = mod.weight
                in_channels = mod.in_channels
                out_channels = mod.out_channels

                # split weight matrix into blocks and randomly drop selected blocks
                if mod.kernel_size == (1, 1):
                    mask = torch.zeros(
                        int(in_channels // block_size * out_channels),
                        device=weight.device,
                    )
                    mask.bernoulli_(p)
                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                else:
                    mask = torch.zeros(
                        weight.size(0), weight.size(1), device=weight.device
                    )
                    mask.bernoulli_(p)
                    mask = (
                        mask.unsqueeze(2)
                        .unsqueeze(3)
                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
                    )

            # scale weights and apply mask
            mask = mask.to(
                torch.bool
            )  # x.bool() is not currently supported in TorchScript
            s = 1 / (1 - p)
            mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
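Usage is a one-line wrap around a supported module. A minimal sketch (the layer sizes here are made up; in_features must divide evenly by block_size):

import torch
import torch.nn as nn

# 10% of the 8-column weight blocks are zeroed during training only
layer = quant_noise(nn.Linear(32, 64), p=0.1, block_size=8)

layer.train()
y_train = layer(torch.randn(4, 32))  # pre-hook zeroes blocks, rescales survivors by 1/(1-p)
layer.eval()
y_eval = layer(torch.randn(4, 32))   # evaluation uses the clean weights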
156
funasr_local/modules/data2vec/utils.py
Normal file
@@ -0,0 +1,156 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr_local.modules.data2vec.multihead_attention import MultiheadAttention


class Fp32LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


class Fp32GroupNorm(nn.GroupNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.group_norm(
            input.float(),
            self.num_groups,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


class TransposeLast(nn.Module):
    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(-2, -1)


class SamePad(nn.Module):
    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x


def pad_to_multiple(x, multiple, dim=-1, value=0):
    # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
    if x is None:
        return None, 0
    tsz = x.size(dim)
    m = tsz / multiple
    remainder = math.ceil(m) * multiple - tsz
    if m.is_integer():
        return x, 0
    pad_offset = (0,) * (-1 - dim) * 2

    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder


def gelu_accurate(x):
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return (
        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
    )


def gelu(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x.float()).type_as(x)


def get_available_activation_fns():
    return [
        "relu",
        "gelu",
        "gelu_fast",  # deprecated
        "gelu_accurate",
        "tanh",
        "linear",
    ]


def get_activation_fn(activation: str):
    """Returns the activation function corresponding to `activation`"""

    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "swish":
        # return the functional form rather than the nn.SiLU class, so the
        # result can be applied directly to a tensor like the other branches
        return torch.nn.functional.silu
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))


def init_bert_params(module):
    """
    Initialize the weights specific to the BERT Model.
    This overrides the default initializations depending on the specified arguments.
    1. If normal_init_linear_weights is set then weights of linear
       layer will be initialized using the normal distribution and
       bias will be set to the specified value.
    2. If normal_init_embed_weights is set then weights of embedding
       layer will be initialized using the normal distribution.
    3. If normal_init_proj_weights is set then weights of
       in_project_weight for MultiHeadAttention initialized using
       the normal distribution (to be validated).
    """

    def normal_(data):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)
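Two small checks of the helpers above (shapes are arbitrary examples): pad_to_multiple pads the chosen dimension up to the next multiple and returns the remainder, and get_activation_fn resolves names to callables.

import torch

# T=50 padded to 56 along dim=-2, remainder 6
x, pad = pad_to_multiple(torch.zeros(2, 50, 16), 8, dim=-2)
assert x.shape == (2, 56, 16) and pad == 6

act = get_activation_fn("gelu_accurate")
out = act(torch.randn(3))  # elementwise tanh-based GELU approximation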
407
funasr_local/modules/data2vec/wav2vec2.py
Normal file
@@ -0,0 +1,407 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr_local.modules.data2vec import utils
from funasr_local.modules.data2vec.multihead_attention import MultiheadAttention


class ConvFeatureExtractionModel(nn.Module):
    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
        in_d: int = 1,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert (
                is_layer_norm and is_group_norm
            ) == False, "layer norm and group norm are exclusive"

            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        utils.TransposeLast(),
                        utils.Fp32LayerNorm(dim, elementwise_affine=True),
                        utils.TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    utils.Fp32GroupNorm(dim, dim, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        else:
            x = x.transpose(1, 2)

        for conv in self.conv_layers:
            x = conv(x)
        return x


def make_conv_pos(e, k, g):
    pos_conv = nn.Conv1d(
        e,
        e,
        kernel_size=k,
        padding=k // 2,
        groups=g,
    )
    dropout = 0
    std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
    nn.init.normal_(pos_conv.weight, mean=0, std=std)
    nn.init.constant_(pos_conv.bias, 0)

    pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2)
    pos_conv = nn.Sequential(pos_conv, utils.SamePad(k), nn.GELU())

    return pos_conv


class TransformerEncoder(nn.Module):
    def build_encoder_layer(self):
        if self.layer_type == "transformer":
            layer = TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=self.encoder_ffn_embed_dim,
                num_attention_heads=self.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=self.attention_dropout,
                activation_dropout=self.activation_dropout,
                activation_fn=self.activation_fn,
                layer_norm_first=self.layer_norm_first,
            )
        else:
            # raise explicitly instead of falling through with an unbound `layer`
            logging.error("Only transformer is supported for data2vec now")
            raise ValueError("unsupported layer_type: {}".format(self.layer_type))
        return layer

    def __init__(
        self,
        # position
        dropout,
        encoder_embed_dim,
        required_seq_len_multiple,
        pos_conv_depth,
        conv_pos,
        conv_pos_groups,
        # transformer layers
        layer_type,
        encoder_layers,
        encoder_ffn_embed_dim,
        encoder_attention_heads,
        attention_dropout,
        activation_dropout,
        activation_fn,
        layer_norm_first,
        encoder_layerdrop,
        max_positions,
    ):
        super().__init__()

        # position
        self.dropout = dropout
        self.embedding_dim = encoder_embed_dim
        self.required_seq_len_multiple = required_seq_len_multiple
        if pos_conv_depth > 1:
            num_layers = pos_conv_depth
            k = max(3, conv_pos // num_layers)

            def make_conv_block(e, k, g, l):
                return nn.Sequential(
                    *[
                        nn.Sequential(
                            nn.Conv1d(
                                e,
                                e,
                                kernel_size=k,
                                padding=k // 2,
                                groups=g,
                            ),
                            utils.SamePad(k),
                            utils.TransposeLast(),
                            torch.nn.LayerNorm(e, elementwise_affine=False),
                            utils.TransposeLast(),
                            nn.GELU(),
                        )
                        for _ in range(l)
                    ]
                )

            self.pos_conv = make_conv_block(
                self.embedding_dim, k, conv_pos_groups, num_layers
            )

        else:
            self.pos_conv = make_conv_pos(
                self.embedding_dim,
                conv_pos,
                conv_pos_groups,
            )

        # transformer layers
        self.layer_type = layer_type
        self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_fn = activation_fn
        self.layer_norm_first = layer_norm_first
        self.layerdrop = encoder_layerdrop
        self.max_positions = max_positions
        self.layers = nn.ModuleList(
            [self.build_encoder_layer() for _ in range(encoder_layers)]
        )
        self.layer_norm = torch.nn.LayerNorm(self.embedding_dim)

        self.apply(utils.init_bert_params)

    def forward(self, x, padding_mask=None, layer=None):
        x, layer_results = self.extract_features(x, padding_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(
        self,
        x,
        padding_mask=None,
        tgt_layer=None,
        min_layer=0,
    ):

        if padding_mask is not None:
            x[padding_mask] = 0

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        # pad to the sequence length dimension
        x, pad_length = utils.pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            padding_mask, _ = utils.pad_to_multiple(
                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
            )
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        r = None
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
            if not self.training or (dropout_probability > self.layerdrop):
                x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask)
                if i >= min_layer:
                    layer_results.append((x, z, lr))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # undo padding
        if pad_length > 0:
            x = x[:, :-pad_length]

            def undo_pad(a, b, c):
                return (
                    a[:-pad_length],
                    b[:-pad_length] if b is not None else b,
                    c[:-pad_length],
                )

            layer_results = [undo_pad(*u) for u in layer_results]

        return x, layer_results

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        return state_dict


class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
    ) -> None:

        super().__init__()
        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Initialize blocks
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = torch.nn.LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = torch.nn.LayerNorm(self.embedding_dim)

    def forward(
        self,
        x: torch.Tensor,  # (T, B, C)
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.
        """
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                attn_mask=self_attn_mask,
                need_weights=False,
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)

            layer_result = x

            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
            )

            x = self.dropout1(x)
            x = residual + x

            x = self.self_attn_layer_norm(x)

            residual = x
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)

            layer_result = x

            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)

        return x, (attn, layer_result)
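A small instantiation sketch of the convolutional front-end above (the layer spec is illustrative, not a recommended configuration): each (dim, kernel, stride) tuple becomes one strided Conv1d block.

import torch

fe = ConvFeatureExtractionModel(
    conv_layers=[(64, 10, 5), (64, 3, 2), (64, 2, 2)],
    dropout=0.0,
    mode="default",
    conv_bias=False,
    in_d=1,
)
wav = torch.randn(2, 16000)  # (B, samples); 2-D input gets a channel axis
feats = fe(wav)              # (B, 64, T') after the strided convolutions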
125
funasr_local/modules/dynamic_conv.py
Normal file
@@ -0,0 +1,125 @@
"""Dynamic Convolution module."""

import numpy
import torch
from torch import nn
import torch.nn.functional as F


MIN_VALUE = float(numpy.finfo(numpy.float32).min)


class DynamicConvolution(nn.Module):
    """Dynamic Convolution layer.

    This implementation is based on
    https://github.com/pytorch/fairseq/tree/master/fairseq

    Args:
        wshare (int): the number of kernels of convolution
        n_feat (int): the number of features
        dropout_rate (float): dropout_rate
        kernel_size (int): kernel size (length)
        use_kernel_mask (bool): Use causal mask or not for convolution kernel
        use_bias (bool): Use bias term or not.

    """

    def __init__(
        self,
        wshare,
        n_feat,
        dropout_rate,
        kernel_size,
        use_kernel_mask=False,
        use_bias=False,
    ):
        """Construct Dynamic Convolution layer."""
        super(DynamicConvolution, self).__init__()

        assert n_feat % wshare == 0
        self.wshare = wshare
        self.use_kernel_mask = use_kernel_mask
        self.dropout_rate = dropout_rate
        self.kernel_size = kernel_size
        self.attn = None

        # linear -> GLU -- -> lightconv -> linear
        #                \        /
        #                 Linear
        self.linear1 = nn.Linear(n_feat, n_feat * 2)
        self.linear2 = nn.Linear(n_feat, n_feat)
        self.linear_weight = nn.Linear(n_feat, self.wshare * 1 * kernel_size)
        nn.init.xavier_uniform_(self.linear_weight.weight)
        self.act = nn.GLU()

        # dynamic conv related
        self.use_bias = use_bias
        if self.use_bias:
            self.bias = nn.Parameter(torch.Tensor(n_feat))

    def forward(self, query, key, value, mask):
        """Forward of 'Dynamic Convolution'.

        This function takes query, key and value but uses only query.
        This is just for compatibility with self-attention layer (attention.py)

        Args:
            query (torch.Tensor): (batch, time1, d_model) input tensor
            key (torch.Tensor): (batch, time2, d_model) NOT USED
            value (torch.Tensor): (batch, time2, d_model) NOT USED
            mask (torch.Tensor): (batch, time1, time2) mask

        Return:
            x (torch.Tensor): (batch, time1, d_model) output

        """
        # linear -> GLU -- -> lightconv -> linear
        #                \        /
        #                 Linear
        x = query
        B, T, C = x.size()
        H = self.wshare
        k = self.kernel_size

        # first linear layer
        x = self.linear1(x)

        # GLU activation
        x = self.act(x)

        # get kernel of convolution
        weight = self.linear_weight(x)  # B x T x kH
        weight = F.dropout(weight, self.dropout_rate, training=self.training)
        weight = weight.view(B, T, H, k).transpose(1, 2).contiguous()  # B x H x T x k
        weight_new = torch.zeros(B * H * T * (T + k - 1), dtype=weight.dtype)
        weight_new = weight_new.view(B, H, T, T + k - 1).fill_(float("-inf"))
        weight_new = weight_new.to(x.device)  # B x H x T x T+k-1
        weight_new.as_strided(
            (B, H, T, k), ((T + k - 1) * T * H, (T + k - 1) * T, T + k, 1)
        ).copy_(weight)
        weight_new = weight_new.narrow(-1, int((k - 1) / 2), T)  # B x H x T x T(k)
        if self.use_kernel_mask:
            kernel_mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0)
            weight_new = weight_new.masked_fill(kernel_mask == 0.0, float("-inf"))
        weight_new = F.softmax(weight_new, dim=-1)
        self.attn = weight_new
        weight_new = weight_new.view(B * H, T, T)

        # convolution
        x = x.transpose(1, 2).contiguous()  # B x C x T
        x = x.view(B * H, int(C / H), T).transpose(1, 2)
        x = torch.bmm(weight_new, x)  # BH x T x C/H
        x = x.transpose(1, 2).contiguous().view(B, C, T)

        if self.use_bias:
            x = x + self.bias.view(1, -1, 1)
        x = x.transpose(1, 2)  # B x T x C

        if mask is not None and not self.use_kernel_mask:
            mask = mask.transpose(-1, -2)
            x = x.masked_fill(mask == 0, 0.0)

        # second linear layer
        x = self.linear2(x)
        return x
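Because the layer mirrors the attention interface, it drops into attention slots unchanged; only the query is consumed. A minimal sketch with arbitrary sizes:

import torch

conv = DynamicConvolution(wshare=4, n_feat=64, dropout_rate=0.1, kernel_size=3)
x = torch.randn(2, 20, 64)      # (batch, time, d_model)
out = conv(x, x, x, mask=None)  # (2, 20, 64); key and value are ignored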
138
funasr_local/modules/dynamic_conv2d.py
Normal file
@@ -0,0 +1,138 @@
"""Dynamic 2-Dimensional Convolution module."""

import numpy
import torch
from torch import nn
import torch.nn.functional as F


MIN_VALUE = float(numpy.finfo(numpy.float32).min)


class DynamicConvolution2D(nn.Module):
    """Dynamic 2-Dimensional Convolution layer.

    This implementation is based on
    https://github.com/pytorch/fairseq/tree/master/fairseq

    Args:
        wshare (int): the number of kernels of convolution
        n_feat (int): the number of features
        dropout_rate (float): dropout_rate
        kernel_size (int): kernel size (length)
        use_kernel_mask (bool): Use causal mask or not for convolution kernel
        use_bias (bool): Use bias term or not.

    """

    def __init__(
        self,
        wshare,
        n_feat,
        dropout_rate,
        kernel_size,
        use_kernel_mask=False,
        use_bias=False,
    ):
        """Construct Dynamic 2-Dimensional Convolution layer."""
        super(DynamicConvolution2D, self).__init__()

        assert n_feat % wshare == 0
        self.wshare = wshare
        self.use_kernel_mask = use_kernel_mask
        self.dropout_rate = dropout_rate
        self.kernel_size = kernel_size
        self.padding_size = int(kernel_size / 2)
        self.attn_t = None
        self.attn_f = None

        # linear -> GLU -- -> lightconv -> linear
        #                \        /
        #                 Linear
        self.linear1 = nn.Linear(n_feat, n_feat * 2)
        self.linear2 = nn.Linear(n_feat * 2, n_feat)
        self.linear_weight = nn.Linear(n_feat, self.wshare * 1 * kernel_size)
        nn.init.xavier_uniform_(self.linear_weight.weight)
        self.linear_weight_f = nn.Linear(n_feat, kernel_size)
        nn.init.xavier_uniform_(self.linear_weight_f.weight)
        self.act = nn.GLU()

        # dynamic conv related
        self.use_bias = use_bias
        if self.use_bias:
            self.bias = nn.Parameter(torch.Tensor(n_feat))

    def forward(self, query, key, value, mask):
        """Forward of 'Dynamic 2-Dimensional Convolution'.

        This function takes query, key and value but uses only query.
        This is just for compatibility with self-attention layer (attention.py)

        Args:
            query (torch.Tensor): (batch, time1, d_model) input tensor
            key (torch.Tensor): (batch, time2, d_model) NOT USED
            value (torch.Tensor): (batch, time2, d_model) NOT USED
            mask (torch.Tensor): (batch, time1, time2) mask

        Return:
            x (torch.Tensor): (batch, time1, d_model) output

        """
        # linear -> GLU -- -> lightconv -> linear
        #                \        /
        #                 Linear
        x = query
        B, T, C = x.size()
        H = self.wshare
        k = self.kernel_size

        # first linear layer
        x = self.linear1(x)

        # GLU activation
        x = self.act(x)

        # convolution of frequency axis
        weight_f = self.linear_weight_f(x).view(B * T, 1, k)  # B x T x k
        self.attn_f = weight_f.view(B, T, k).unsqueeze(1)
        xf = F.conv1d(
            x.view(1, B * T, C), weight_f, padding=self.padding_size, groups=B * T
        )
        xf = xf.view(B, T, C)

        # get kernel of convolution
        weight = self.linear_weight(x)  # B x T x kH
        weight = F.dropout(weight, self.dropout_rate, training=self.training)
        weight = weight.view(B, T, H, k).transpose(1, 2).contiguous()  # B x H x T x k
        weight_new = torch.zeros(B * H * T * (T + k - 1), dtype=weight.dtype)
        weight_new = weight_new.view(B, H, T, T + k - 1).fill_(float("-inf"))
        weight_new = weight_new.to(x.device)  # B x H x T x T+k-1
        weight_new.as_strided(
            (B, H, T, k), ((T + k - 1) * T * H, (T + k - 1) * T, T + k, 1)
        ).copy_(weight)
        weight_new = weight_new.narrow(-1, int((k - 1) / 2), T)  # B x H x T x T(k)
        if self.use_kernel_mask:
            kernel_mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0)
            weight_new = weight_new.masked_fill(kernel_mask == 0.0, float("-inf"))
        weight_new = F.softmax(weight_new, dim=-1)
        self.attn_t = weight_new
        weight_new = weight_new.view(B * H, T, T)

        # convolution
        x = x.transpose(1, 2).contiguous()  # B x C x T
        x = x.view(B * H, int(C / H), T).transpose(1, 2)
        x = torch.bmm(weight_new, x)
        x = x.transpose(1, 2).contiguous().view(B, C, T)

        if self.use_bias:
            x = x + self.bias.view(1, -1, 1)
        x = x.transpose(1, 2)  # B x T x C
        x = torch.cat((x, xf), -1)  # B x T x 2C

        if mask is not None and not self.use_kernel_mask:
            mask = mask.transpose(-1, -2)
            x = x.masked_fill(mask == 0, 0.0)

        # second linear layer
        x = self.linear2(x)
        return x
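The 2-D variant runs a second per-position convolution over the feature axis and concatenates the two streams, which is why linear2 maps 2*n_feat back to n_feat. A minimal sketch:

import torch

conv2d = DynamicConvolution2D(wshare=4, n_feat=64, dropout_rate=0.1, kernel_size=3)
x = torch.randn(2, 20, 64)
out = conv2d(x, x, x, mask=None)  # (2, 20, 64) after the 2*n_feat -> n_feat projection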
403
funasr_local/modules/e2e_asr_common.py
Normal file
@@ -0,0 +1,403 @@
#!/usr/bin/env python3
# encoding: utf-8

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Common functions for ASR."""

from typing import List, Optional, Tuple

import json
import logging
import sys

from itertools import groupby
import numpy as np
import six
import torch

from funasr_local.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr_local.models.joint_net.joint_network import JointNetwork


def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
    """End detection.

    described in Eq. (50) of S. Watanabe et al
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"

    :param ended_hyps: list of hypotheses that have already ended
    :param i: current decoding step
    :param M: number of recent hypothesis lengths to check
    :param D_end: score-gap threshold below which a length counts as degraded
    :return: True if decoding can be terminated
    """
    if len(ended_hyps) == 0:
        return False
    count = 0
    best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
    for m in six.moves.range(M):
        # get ended_hyps with their length is i - m
        hyp_length = i - m
        hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
        if len(hyps_same_length) > 0:
            best_hyp_same_length = sorted(
                hyps_same_length, key=lambda x: x["score"], reverse=True
            )[0]
            if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
                count += 1

    if count == M:
        return True
    else:
        return False


# TODO(takaaki-hori): add different smoothing methods
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
    """Obtain label distribution for loss smoothing.

    :param odim: output dimension (vocabulary size)
    :param lsm_type: label smoothing type ("unigram" is supported)
    :param blank: blank symbol id
    :param transcript: path to a transcript json file
    :return: label distribution
    """
    if transcript is not None:
        with open(transcript, "rb") as f:
            trans_json = json.load(f)["utts"]

    if lsm_type == "unigram":
        assert transcript is not None, (
            "transcript is required for %s label smoothing" % lsm_type
        )
        labelcount = np.zeros(odim)
        for k, v in trans_json.items():
            ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
            # to avoid an error when there is no text in an utterance
            if len(ids) > 0:
                labelcount[ids] += 1
        labelcount[odim - 1] = len(trans_json)  # count <eos> once per utterance
        labelcount[labelcount == 0] = 1  # flooring
        labelcount[blank] = 0  # remove counts for blank
        labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
    else:
        logging.error("Error: unexpected label smoothing type: %s" % lsm_type)
        sys.exit()

    return labeldist


def get_vgg2l_odim(idim, in_channel=3, out_channel=128):
    """Return the output size of the VGG frontend.

    :param in_channel: input channel size
    :param out_channel: output channel size
    :return: output size
    :rtype int
    """
    idim = idim / in_channel
    idim = np.ceil(np.array(idim, dtype=np.float32) / 2)  # 1st max pooling
    idim = np.ceil(np.array(idim, dtype=np.float32) / 2)  # 2nd max pooling
    return int(idim) * out_channel  # number of channels


class ErrorCalculator(object):
    """Calculate CER and WER for E2E_ASR and CTC models during training.

    :param y_hats: numpy array with predicted text
    :param y_pads: numpy array with true (target) text
    :param char_list: list of token units
    :param sym_space: space symbol
    :param sym_blank: blank symbol
    :return:
    """

    def __init__(
        self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
    ):
        """Construct an ErrorCalculator object."""
        super(ErrorCalculator, self).__init__()

        self.report_cer = report_cer
        self.report_wer = report_wer

        self.char_list = char_list
        self.space = sym_space
        self.blank = sym_blank
        self.idx_blank = self.char_list.index(self.blank)
        if self.space in self.char_list:
            self.idx_space = self.char_list.index(self.space)
        else:
            self.idx_space = None

    def __call__(self, ys_hat, ys_pad, is_ctc=False):
        """Calculate sentence-level WER/CER score.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :param bool is_ctc: calculate CER score for CTC
        :return: sentence-level WER score
        :rtype float
        :return: sentence-level CER score
        :rtype float
        """
        cer, wer = None, None
        if is_ctc:
            return self.calculate_cer_ctc(ys_hat, ys_pad)
        elif not self.report_cer and not self.report_wer:
            return cer, wer

        seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
        if self.report_cer:
            cer = self.calculate_cer(seqs_hat, seqs_true)

        if self.report_wer:
            wer = self.calculate_wer(seqs_hat, seqs_true)
        return cer, wer

    def calculate_cer_ctc(self, ys_hat, ys_pad):
        """Calculate sentence-level CER score for CTC.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: average sentence-level CER score
        :rtype float
        """
        import editdistance

        cers, char_ref_lens = [], []
        for i, y in enumerate(ys_hat):
            y_hat = [x[0] for x in groupby(y)]
            y_true = ys_pad[i]
            seq_hat, seq_true = [], []
            for idx in y_hat:
                idx = int(idx)
                if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
                    seq_hat.append(self.char_list[int(idx)])

            for idx in y_true:
                idx = int(idx)
                if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
                    seq_true.append(self.char_list[int(idx)])

            hyp_chars = "".join(seq_hat)
            ref_chars = "".join(seq_true)
            if len(ref_chars) > 0:
                cers.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))

        cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
        return cer_ctc

    def convert_to_char(self, ys_hat, ys_pad):
        """Convert index to character.

        :param torch.Tensor seqs_hat: prediction (batch, seqlen)
        :param torch.Tensor seqs_true: reference (batch, seqlen)
        :return: token list of prediction
        :rtype list
        :return: token list of reference
        :rtype list
        """
        seqs_hat, seqs_true = [], []
        for i, y_hat in enumerate(ys_hat):
            y_true = ys_pad[i]
            eos_true = np.where(y_true == -1)[0]
            ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
            # NOTE: padding index (-1) in y_true is used to pad y_hat
            seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
            seq_hat_text = "".join(seq_hat).replace(self.space, " ")
            seq_hat_text = seq_hat_text.replace(self.blank, "")
            seq_true_text = "".join(seq_true).replace(self.space, " ")
            seqs_hat.append(seq_hat_text)
            seqs_true.append(seq_true_text)
        return seqs_hat, seqs_true

    def calculate_cer(self, seqs_hat, seqs_true):
        """Calculate sentence-level CER score.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level CER score
        :rtype float
        """
        import editdistance

        char_eds, char_ref_lens = [], []
        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_chars = seq_hat_text.replace(" ", "")
            ref_chars = seq_true_text.replace(" ", "")
            char_eds.append(editdistance.eval(hyp_chars, ref_chars))
            char_ref_lens.append(len(ref_chars))
        return float(sum(char_eds)) / sum(char_ref_lens)

    def calculate_wer(self, seqs_hat, seqs_true):
        """Calculate sentence-level WER score.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level WER score
        :rtype float
        """
        import editdistance

        word_eds, word_ref_lens = [], []
        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_words = seq_hat_text.split()
            ref_words = seq_true_text.split()
            word_eds.append(editdistance.eval(hyp_words, ref_words))
            word_ref_lens.append(len(ref_words))
        return float(sum(word_eds)) / sum(word_ref_lens)


class ErrorCalculatorTransducer:
    """Calculate CER and WER for transducer models.

    Args:
        decoder: Decoder module.
        joint_network: Joint Network module.
        token_list: List of token units.
        sym_space: Space symbol.
        sym_blank: Blank symbol.
        report_cer: Whether to compute CER.
        report_wer: Whether to compute WER.

    """

    def __init__(
        self,
        decoder,
        joint_network: JointNetwork,
        token_list: List[int],
        sym_space: str,
        sym_blank: str,
        report_cer: bool = False,
        report_wer: bool = False,
    ) -> None:
        """Construct an ErrorCalculatorTransducer object."""
        super().__init__()

        self.beam_search = BeamSearchTransducer(
            decoder=decoder,
            joint_network=joint_network,
            beam_size=1,
            search_type="default",
            score_norm=False,
        )

        self.decoder = decoder

        self.token_list = token_list
        self.space = sym_space
        self.blank = sym_blank

        self.report_cer = report_cer
        self.report_wer = report_wer

    def __call__(
        self, encoder_out: torch.Tensor, target: torch.Tensor, encoder_out_lens: torch.Tensor,
    ) -> Tuple[Optional[float], Optional[float]]:
        """Calculate sentence-level WER or/and CER score for Transducer model.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
            encoder_out_lens: Encoder output sequences length. (B,)

        Returns:
            : Sentence-level CER score.
            : Sentence-level WER score.

        """
        cer, wer = None, None

        batchsize = int(encoder_out.size(0))

        encoder_out = encoder_out.to(next(self.decoder.parameters()).device)

        batch_nbest = [
            self.beam_search(encoder_out[b][: encoder_out_lens[b]])
            for b in range(batchsize)
        ]
        pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]

        char_pred, char_target = self.convert_to_char(pred, target)

        if self.report_cer:
            cer = self.calculate_cer(char_pred, char_target)

        if self.report_wer:
            wer = self.calculate_wer(char_pred, char_target)

        return cer, wer

    def convert_to_char(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> Tuple[List, List]:
        """Convert label ID sequences to character sequences.

        Args:
            pred: Prediction label ID sequences. (B, U)
            target: Target label ID sequences. (B, L)

        Returns:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)

        """
        char_pred, char_target = [], []

        for i, pred_i in enumerate(pred):
            char_pred_i = [self.token_list[int(h)] for h in pred_i]
            char_target_i = [self.token_list[int(r)] for r in target[i]]

            char_pred_i = "".join(char_pred_i).replace(self.space, " ")
            char_pred_i = char_pred_i.replace(self.blank, "")

            char_target_i = "".join(char_target_i).replace(self.space, " ")
            char_target_i = char_target_i.replace(self.blank, "")

            char_pred.append(char_pred_i)
            char_target.append(char_target_i)

        return char_pred, char_target

    def calculate_cer(
        self, char_pred: torch.Tensor, char_target: torch.Tensor
    ) -> float:
        """Calculate sentence-level CER score.

        Args:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)

        Returns:
            : Average sentence-level CER score.

        """
        import editdistance

        distances, lens = [], []

        for i, char_pred_i in enumerate(char_pred):
            pred = char_pred_i.replace(" ", "")
            target = char_target[i].replace(" ", "")
            distances.append(editdistance.eval(pred, target))
            lens.append(len(target))

        return float(sum(distances)) / sum(lens)

    def calculate_wer(
        self, char_pred: torch.Tensor, char_target: torch.Tensor
    ) -> float:
        """Calculate sentence-level WER score.

        Args:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)

        Returns:
            : Average sentence-level WER score.

        """
        import editdistance

        distances, lens = [], []

        for i, char_pred_i in enumerate(char_pred):
            pred = char_pred_i.replace("▁", " ").split()
            target = char_target[i].replace("▁", " ").split()

            distances.append(editdistance.eval(pred, target))
            lens.append(len(target))

        return float(sum(distances)) / sum(lens)
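A minimal sketch of ErrorCalculator on a toy vocabulary (editdistance must be installed; all ids and tokens here are invented for illustration):

import torch

char_list = ["<blank>", "a", "b", "<space>", "<eos>"]
calc = ErrorCalculator(char_list, "<space>", "<blank>",
                       report_cer=True, report_wer=True)

ys_hat = torch.tensor([[1, 3, 2, -1]])  # hypothesis "a b" (-1 pads)
ys_pad = torch.tensor([[1, 3, 1, -1]])  # reference  "a a"
cer, wer = calc(ys_hat, ys_pad)         # cer == 0.5, wer == 0.5 here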
0
funasr_local/modules/eend_ola/__init__.py
Normal file
133
funasr_local/modules/eend_ola/encoder.py
Normal file
@@ -0,0 +1,133 @@
import math

import torch
import torch.nn.functional as F
from torch import nn


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, n_units, h=8, dropout_rate=0.1):
        super(MultiHeadSelfAttention, self).__init__()
        self.linearQ = nn.Linear(n_units, n_units)
        self.linearK = nn.Linear(n_units, n_units)
        self.linearV = nn.Linear(n_units, n_units)
        self.linearO = nn.Linear(n_units, n_units)
        self.d_k = n_units // h
        self.h = h
        self.dropout = nn.Dropout(dropout_rate)

    def __call__(self, x, batch_size, x_mask):
        q = self.linearQ(x).view(batch_size, -1, self.h, self.d_k)
        k = self.linearK(x).view(batch_size, -1, self.h, self.d_k)
        v = self.linearV(x).view(batch_size, -1, self.h, self.d_k)
        scores = torch.matmul(
            q.permute(0, 2, 1, 3), k.permute(0, 2, 3, 1)) / math.sqrt(self.d_k)
        if x_mask is not None:
            x_mask = x_mask.unsqueeze(1)
            scores = scores.masked_fill(x_mask == 0, -1e9)
        self.att = F.softmax(scores, dim=3)
        p_att = self.dropout(self.att)
        x = torch.matmul(p_att, v.permute(0, 2, 1, 3))
        x = x.permute(0, 2, 1, 3).contiguous().view(-1, self.h * self.d_k)
        return self.linearO(x)


class PositionwiseFeedForward(nn.Module):
    def __init__(self, n_units, d_units, dropout_rate):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(n_units, d_units)
        self.linear2 = nn.Linear(d_units, n_units)
        self.dropout = nn.Dropout(dropout_rate)

    def __call__(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))


class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)


class EENDOLATransformerEncoder(nn.Module):
    def __init__(self,
                 idim: int,
                 n_layers: int,
                 n_units: int,
                 e_units: int = 2048,
                 h: int = 4,
                 dropout_rate: float = 0.1,
                 use_pos_emb: bool = False):
        super(EENDOLATransformerEncoder, self).__init__()
        self.lnorm_in = nn.LayerNorm(n_units)
        self.n_layers = n_layers
        self.dropout = nn.Dropout(dropout_rate)
        for i in range(n_layers):
            setattr(self, '{}{:d}'.format("lnorm1_", i),
                    nn.LayerNorm(n_units))
            setattr(self, '{}{:d}'.format("self_att_", i),
                    MultiHeadSelfAttention(n_units, h))
            setattr(self, '{}{:d}'.format("lnorm2_", i),
                    nn.LayerNorm(n_units))
            setattr(self, '{}{:d}'.format("ff_", i),
                    PositionwiseFeedForward(n_units, e_units, dropout_rate))
        self.lnorm_out = nn.LayerNorm(n_units)
        if use_pos_emb:
            self.pos_enc = torch.nn.Sequential(
                torch.nn.Linear(idim, n_units),
                torch.nn.LayerNorm(n_units),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                PositionalEncoding(n_units, dropout_rate),
            )
        else:
            self.linear_in = nn.Linear(idim, n_units)
            self.pos_enc = None

    def __call__(self, x, x_mask=None):
        BT_size = x.shape[0] * x.shape[1]
        if self.pos_enc is not None:
            e = self.pos_enc(x)
            e = e.view(BT_size, -1)
        else:
            e = self.linear_in(x.reshape(BT_size, -1))
        for i in range(self.n_layers):
            e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
            s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0], x_mask)
            e = e + self.dropout(s)
            e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
            s = getattr(self, '{}{:d}'.format("ff_", i))(e)
            e = e + self.dropout(s)
        return self.lnorm_out(e)
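The encoder flattens batch and time before the per-frame self-attention stack, so the output is (B*T, n_units). A minimal sketch (dimensions are arbitrary):

import torch

enc = EENDOLATransformerEncoder(idim=345, n_layers=2, n_units=256, h=4)
x = torch.randn(2, 50, 345)  # (B, T, idim) acoustic features
e = enc(x)                   # (100, 256): batch and time flattened together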
50
funasr_local/modules/eend_ola/encoder_decoder_attractor.py
Normal file
@@ -0,0 +1,50 @@
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class EncoderDecoderAttractor(nn.Module):

    def __init__(self, n_units, encoder_dropout=0.1, decoder_dropout=0.1):
        super(EncoderDecoderAttractor, self).__init__()
        self.enc0_dropout = nn.Dropout(encoder_dropout)
        self.encoder = nn.LSTM(n_units, n_units, 1, batch_first=True, dropout=encoder_dropout)
        self.dec0_dropout = nn.Dropout(decoder_dropout)
        self.decoder = nn.LSTM(n_units, n_units, 1, batch_first=True, dropout=decoder_dropout)
        self.counter = nn.Linear(n_units, 1)
        self.n_units = n_units

    def forward_core(self, xs, zeros):
        ilens = torch.from_numpy(np.array([x.shape[0] for x in xs])).to(torch.int64)
        xs = [self.enc0_dropout(x) for x in xs]
        xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
        xs = nn.utils.rnn.pack_padded_sequence(xs, ilens, batch_first=True, enforce_sorted=False)
        _, (hx, cx) = self.encoder(xs)
        zlens = torch.from_numpy(np.array([z.shape[0] for z in zeros])).to(torch.int64)
        max_zlen = torch.max(zlens).to(torch.int).item()
        zeros = [self.enc0_dropout(z) for z in zeros]
        zeros = nn.utils.rnn.pad_sequence(zeros, batch_first=True, padding_value=-1)
        zeros = nn.utils.rnn.pack_padded_sequence(zeros, zlens, batch_first=True, enforce_sorted=False)
        attractors, (_, _) = self.decoder(zeros, (hx, cx))
        attractors = nn.utils.rnn.pad_packed_sequence(attractors, batch_first=True, padding_value=-1,
                                                      total_length=max_zlen)[0]
        attractors = [att[:zlens[i].to(torch.int).item()] for i, att in enumerate(attractors)]
        return attractors

    def forward(self, xs, n_speakers):
        zeros = [torch.zeros(n_spk + 1, self.n_units).to(torch.float32).to(xs[0].device) for n_spk in n_speakers]
        attractors = self.forward_core(xs, zeros)
        labels = torch.cat([torch.from_numpy(np.array([[1] * n_spk + [0]], np.float32)) for n_spk in n_speakers], dim=1)
        labels = labels.to(xs[0].device)
        logit = torch.cat([self.counter(att).view(-1, n_spk + 1) for att, n_spk in zip(attractors, n_speakers)], dim=1)
        loss = F.binary_cross_entropy(torch.sigmoid(logit), labels)

        attractors = [att[slice(0, att.shape[0] - 1)] for att in attractors]
        return loss, attractors

    def estimate(self, xs, max_n_speakers=15):
        zeros = [torch.zeros(max_n_speakers, self.n_units).to(torch.float32).to(xs[0].device) for _ in xs]
        attractors = self.forward_core(xs, zeros)
        probs = [torch.sigmoid(torch.flatten(self.counter(att))) for att in attractors]
        return attractors, probs
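At inference time estimate() feeds max_n_speakers zero vectors to the decoder and scores each resulting attractor for existence. A minimal sketch with two variable-length inputs:

import torch

eda = EncoderDecoderAttractor(n_units=256)
xs = [torch.randn(120, 256), torch.randn(80, 256)]  # frame embeddings per recording
attractors, probs = eda.estimate(xs, max_n_speakers=15)
# probs[i]: existence probability of each of the 15 attractors for recording i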
67
funasr_local/modules/eend_ola/utils/losses.py
Normal file
@@ -0,0 +1,67 @@
import numpy as np
import torch
import torch.nn.functional as F
from itertools import permutations
from torch import nn


def standard_loss(ys, ts, label_delay=0):
    losses = [F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts)]
    loss = torch.sum(torch.stack(losses))
    n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(torch.float32).to(ys[0].device)
    loss = loss / n_frames
    return loss


def batch_pit_n_speaker_loss(ys, ts, n_speakers_list):
    max_n_speakers = ts[0].shape[1]
    olens = [y.shape[0] for y in ys]
    ys = nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-1)
    ys_mask = [torch.ones(olen).to(ys.device) for olen in olens]
    ys_mask = torch.nn.utils.rnn.pad_sequence(ys_mask, batch_first=True, padding_value=0).unsqueeze(-1)

    losses = []
    for shift in range(max_n_speakers):
        ts_roll = [torch.roll(t, -shift, dims=1) for t in ts]
        ts_roll = nn.utils.rnn.pad_sequence(ts_roll, batch_first=True, padding_value=-1)
        loss = F.binary_cross_entropy(torch.sigmoid(ys), ts_roll, reduction='none')
        if ys_mask is not None:
            loss = loss * ys_mask
        loss = torch.sum(loss, dim=1)
        losses.append(loss)
    losses = torch.stack(losses, dim=2)

    perms = np.array(list(permutations(range(max_n_speakers)))).astype(np.float32)
    perms = torch.from_numpy(perms).to(losses.device)
    y_ind = torch.arange(max_n_speakers, dtype=torch.float32, device=losses.device)
    t_inds = torch.fmod(perms - y_ind, max_n_speakers).to(torch.long)

    losses_perm = []
    for t_ind in t_inds:
        losses_perm.append(
            torch.mean(losses[:, y_ind.to(torch.long), t_ind], dim=1))
    losses_perm = torch.stack(losses_perm, dim=1)

    def select_perm_indices(num, max_num):
        perms = list(permutations(range(max_num)))
        sub_perms = list(permutations(range(num)))
        return [
            [x[:num] for x in perms].index(perm)
            for perm in sub_perms]

    masks = torch.full_like(losses_perm, device=losses.device, fill_value=float('inf'))
    for i, t in enumerate(ts):
        n_speakers = n_speakers_list[i]
        indices = select_perm_indices(n_speakers, max_n_speakers)
        masks[i, indices] = 0
    losses_perm += masks

    min_loss = torch.sum(torch.min(losses_perm, dim=1)[0])
    n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(losses.device)
    min_loss = min_loss / n_frames

    min_indices = torch.argmin(losses_perm, dim=1)
    labels_perm = [t[:, perms[idx].to(torch.long)] for t, idx in zip(ts, min_indices)]
    labels_perm = [t[:, :n_speakers] for t, n_speakers in zip(labels_perm, n_speakers_list)]

    return min_loss, labels_perm
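batch_pit_n_speaker_loss evaluates all speaker permutations and keeps the cheapest one per utterance, returning the correspondingly permuted labels. A minimal sketch with two speakers (random tensors, equal lengths to keep the example simple):

import torch

ys = [torch.randn(30, 2), torch.randn(30, 2)]  # per-frame logits
ts = [torch.randint(0, 2, (30, 2)).float(),
      torch.randint(0, 2, (30, 2)).float()]    # per-frame speaker activity
loss, labels_perm = batch_pit_n_speaker_loss(ys, ts, n_speakers_list=[2, 2])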
95
funasr_local/modules/eend_ola/utils/power.py
Normal file
95
funasr_local/modules/eend_ola/utils/power.py
Normal file
@@ -0,0 +1,95 @@
import numpy as np
import torch
import torch.multiprocessing
import torch.nn.functional as F
from itertools import combinations
from itertools import permutations


def generate_mapping_dict(max_speaker_num=6, max_olp_speaker_num=3):
    all_kinds = []
    all_kinds.append(0)
    for i in range(max_olp_speaker_num):
        selected_num = i + 1
        coms = np.array(list(combinations(np.arange(max_speaker_num), selected_num)))
        for com in coms:
            tmp = np.zeros(max_speaker_num)
            tmp[com] = 1
            item = int(raw_dec_trans(tmp.reshape(1, -1), max_speaker_num)[0])
            all_kinds.append(item)
    all_kinds_order = sorted(all_kinds)

    mapping_dict = {}
    mapping_dict['dec2label'] = {}
    mapping_dict['label2dec'] = {}
    for i in range(len(all_kinds_order)):
        dec = all_kinds_order[i]
        mapping_dict['dec2label'][dec] = i
        mapping_dict['label2dec'][i] = dec
    oov_id = len(all_kinds_order)
    mapping_dict['oov'] = oov_id
    return mapping_dict


def raw_dec_trans(x, max_speaker_num):
    num_list = []
    for i in range(max_speaker_num):
        num_list.append(x[:, i])
    base = 1
    T = x.shape[0]
    res = np.zeros((T))
    for num in num_list:
        res += num * base
        base = base * 2
    return res


def mapping_func(num, mapping_dict):
    if num in mapping_dict['dec2label'].keys():
        label = mapping_dict['dec2label'][num]
    else:
        label = mapping_dict['oov']
    return label


def dec_trans(x, max_speaker_num, mapping_dict):
    num_list = []
    for i in range(max_speaker_num):
        num_list.append(x[:, i])
    base = 1
    T = x.shape[0]
    res = np.zeros((T))
    for num in num_list:
        res += num * base
        base = base * 2
    res = np.array([mapping_func(i, mapping_dict) for i in res])
    return res


def create_powerlabel(label, mapping_dict, max_speaker_num=6, max_olp_speaker_num=3):
    T, C = label.shape
    padding_label = np.zeros((T, max_speaker_num))
    padding_label[:, :C] = label
    out_label = dec_trans(padding_label, max_speaker_num, mapping_dict)
    out_label = torch.from_numpy(out_label)
    return out_label


def generate_perm_pse(label, n_speaker, mapping_dict, max_speaker_num, max_olp_speaker_num=3):
    perms = np.array(list(permutations(range(n_speaker)))).astype(np.float32)
    perms = torch.from_numpy(perms).to(label.device).to(torch.int64)
    perm_labels = [label[:, perm] for perm in perms]
    perm_pse_labels = [create_powerlabel(perm_label.cpu().numpy(), mapping_dict, max_speaker_num).
                       to(perm_label.device, non_blocking=True) for perm_label in perm_labels]
    return perm_labels, perm_pse_labels


def generate_min_pse(label, n_speaker, mapping_dict, max_speaker_num, pse_logit, max_olp_speaker_num=3):
    perm_labels, perm_pse_labels = generate_perm_pse(label, n_speaker, mapping_dict, max_speaker_num,
                                                     max_olp_speaker_num=max_olp_speaker_num)
    losses = [F.cross_entropy(input=pse_logit, target=perm_pse_label.to(torch.long)) * len(pse_logit)
              for perm_pse_label in perm_pse_labels]
    loss = torch.stack(losses)
    min_index = torch.argmin(loss)
    selected_perm_label, selected_pse_label = perm_labels[min_index], perm_pse_labels[min_index]
    return selected_perm_label, selected_pse_label
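A quick sketch of the power-set encoding above (editorial example using only the functions in this file): each frame's multi-hot speaker vector is packed into a decimal via the binary weights in `raw_dec_trans`, then remapped to a dense class id by the dictionary:

mapping_dict = generate_mapping_dict(max_speaker_num=6, max_olp_speaker_num=3)
label = np.array([[1, 0], [1, 1], [0, 0]])     # T=3 frames, C=2 speakers
pse = create_powerlabel(label, mapping_dict)   # one class id per frame
# frame 0 encodes to decimal 1 (speaker 0 only), frame 1 to decimal 3 (speakers 0 and 1)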
159
funasr_local/modules/eend_ola/utils/report.py
Normal file
@@ -0,0 +1,159 @@
import copy
import numpy as np
import time
import torch
from funasr_local.modules.eend_ola.utils.power import create_powerlabel
from itertools import combinations

metrics = [
    ('diarization_error', 'speaker_scored', 'DER'),
    ('speech_miss', 'speech_scored', 'SAD_MR'),
    ('speech_falarm', 'speech_scored', 'SAD_FR'),
    ('speaker_miss', 'speaker_scored', 'MI'),
    ('speaker_falarm', 'speaker_scored', 'FA'),
    ('speaker_error', 'speaker_scored', 'CF'),
    ('correct', 'frames', 'accuracy')
]


def recover_prediction(y, n_speaker):
    if n_speaker <= 1:
        return y
    elif n_speaker == 2:
        com_index = torch.from_numpy(
            np.array(list(combinations(np.arange(n_speaker), 2)))).to(y.dtype)
        num_coms = com_index.shape[0]
        y_single = y[:, :-num_coms]
        y_olp = y[:, -num_coms:]
        olp_map_index = torch.where(y_olp > 0.5)
        olp_map_index = torch.stack(olp_map_index, dim=1)
        com_map_index = com_index[olp_map_index[:, -1]]
        speaker_map_index = torch.from_numpy(np.array(com_map_index)).view(-1).to(torch.int64)
        frame_map_index = olp_map_index[:, 0][:, None].repeat([1, 2]).view(-1).to(torch.int64)
        y_single[frame_map_index] = 0
        y_single[frame_map_index, speaker_map_index] = 1
        return y_single
    else:
        olp2_com_index = torch.from_numpy(np.array(list(combinations(np.arange(n_speaker), 2)))).to(y.dtype)
        olp2_num_coms = olp2_com_index.shape[0]
        olp3_com_index = torch.from_numpy(np.array(list(combinations(np.arange(n_speaker), 3)))).to(y.dtype)
        olp3_num_coms = olp3_com_index.shape[0]
        y_single = y[:, :n_speaker]
        y_olp2 = y[:, n_speaker:n_speaker + olp2_num_coms]
        y_olp3 = y[:, -olp3_num_coms:]

        olp3_map_index = torch.where(y_olp3 > 0.5)
        olp3_map_index = torch.stack(olp3_map_index, dim=1)
        olp3_com_map_index = olp3_com_index[olp3_map_index[:, -1]]
        olp3_speaker_map_index = torch.from_numpy(np.array(olp3_com_map_index)).view(-1).to(torch.int64)
        olp3_frame_map_index = olp3_map_index[:, 0][:, None].repeat([1, 3]).view(-1).to(torch.int64)
        y_single[olp3_frame_map_index] = 0
        y_single[olp3_frame_map_index, olp3_speaker_map_index] = 1
        y_olp2[olp3_frame_map_index] = 0

        olp2_map_index = torch.where(y_olp2 > 0.5)
        olp2_map_index = torch.stack(olp2_map_index, dim=1)
        olp2_com_map_index = olp2_com_index[olp2_map_index[:, -1]]
        olp2_speaker_map_index = torch.from_numpy(np.array(olp2_com_map_index)).view(-1).to(torch.int64)
        olp2_frame_map_index = olp2_map_index[:, 0][:, None].repeat([1, 2]).view(-1).to(torch.int64)
        y_single[olp2_frame_map_index] = 0
        y_single[olp2_frame_map_index, olp2_speaker_map_index] = 1
        return y_single


class PowerReporter():
    def __init__(self, valid_data_loader, mapping_dict, max_n_speaker):
        valid_data_loader_cp = copy.deepcopy(valid_data_loader)
        self.valid_data_loader = valid_data_loader_cp
        del valid_data_loader
        self.mapping_dict = mapping_dict
        self.max_n_speaker = max_n_speaker

    def report(self, model, eidx, device):
        self.report_val(model, eidx, device)

    def report_val(self, model, eidx, device):
        model.eval()
        ud_valid_start = time.time()
        valid_res, valid_loss, stats_keys, vad_valid_accuracy = self.report_core(model, self.valid_data_loader, device)

        # Epoch Display
        valid_der = valid_res['diarization_error'] / valid_res['speaker_scored']
        valid_accuracy = valid_res['correct'].to(torch.float32) / valid_res['frames'] * 100
        vad_valid_accuracy = vad_valid_accuracy * 100
        print('Epoch ', eidx + 1, 'Valid Loss ', valid_loss, 'Valid_DER %.5f' % valid_der,
              'Valid_Accuracy %.5f%% ' % valid_accuracy, 'VAD_Valid_Accuracy %.5f%% ' % vad_valid_accuracy)
        ud_valid = (time.time() - ud_valid_start) / 60.
        print('Valid cost time ... ', ud_valid)

    def inv_mapping_func(self, label, mapping_dict):
        if not isinstance(label, int):
            label = int(label)
        if label in mapping_dict['label2dec'].keys():
            num = mapping_dict['label2dec'][label]
        else:
            num = -1
        return num

    def report_core(self, model, data_loader, device):
        res = {}
        for item in metrics:
            res[item[0]] = 0.
            res[item[1]] = 0.
        with torch.no_grad():
            loss_s = 0.
            uidx = 0
            for xs, ts, orders in data_loader:
                xs = [x.to(device) for x in xs]
                ts = [t.to(device) for t in ts]
                orders = [o.to(device) for o in orders]
                loss, pit_loss, mpit_loss, att_loss, ys, logits, labels, attractors = model(xs, ts, orders)
                loss_s += loss.item()
                uidx += 1

                for logit, t, att in zip(logits, labels, attractors):
                    pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)  # (T, )
                    oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
                    for i in oov_index:
                        if i > 0:
                            pred[i] = pred[i - 1]
                        else:
                            pred[i] = 0
                    pred = [self.inv_mapping_func(i, self.mapping_dict) for i in pred]
                    decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
                    decisions = torch.from_numpy(
                        np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(att.device).to(
                        torch.float32)
                    decisions = decisions[:, :att.shape[0]]

                    stats = self.calc_diarization_error(decisions, t)
                    res['speaker_scored'] += stats['speaker_scored']
                    res['speech_scored'] += stats['speech_scored']
                    res['frames'] += stats['frames']
                    for item in metrics:
                        res[item[0]] += stats[item[0]]
            loss_s /= uidx
            vad_acc = 0

        return res, loss_s, stats.keys(), vad_acc

    def calc_diarization_error(self, decisions, label, label_delay=0):
        label = label[:len(label) - label_delay, ...]
        n_ref = torch.sum(label, dim=-1)
        n_sys = torch.sum(decisions, dim=-1)
        res = {}
        res['speech_scored'] = torch.sum(n_ref > 0)
        res['speech_miss'] = torch.sum((n_ref > 0) & (n_sys == 0))
        res['speech_falarm'] = torch.sum((n_ref == 0) & (n_sys > 0))
        res['speaker_scored'] = torch.sum(n_ref)
        res['speaker_miss'] = torch.sum(torch.max(n_ref - n_sys, torch.zeros_like(n_ref)))
        res['speaker_falarm'] = torch.sum(torch.max(n_sys - n_ref, torch.zeros_like(n_ref)))
        n_map = torch.sum(((label == 1) & (decisions == 1)), dim=-1).to(torch.float32)
        res['speaker_error'] = torch.sum(torch.min(n_ref, n_sys) - n_map)
        res['correct'] = torch.sum(label == decisions) / label.shape[1]
        res['diarization_error'] = (
            res['speaker_miss'] + res['speaker_falarm'] + res['speaker_error'])
        res['frames'] = len(label)
        return res
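To see what `recover_prediction` does for two speakers (editorial toy example; the third column is the score of the (0, 1) overlap class):

y = torch.tensor([[0.9, 0.1, 0.0],
                  [0.2, 0.1, 0.8]])
out = recover_prediction(y.clone(), n_speaker=2)
# frame 0 keeps its single-speaker scores; frame 1 is rewritten to [1., 1.],
# i.e. both speakers are marked active wherever the overlap class fires above 0.5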
511
funasr_local/modules/embedding.py
Normal file
@@ -0,0 +1,511 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Positional Encoding Module."""

import math
import torch
import torch.nn.functional as F


def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Perform pre-hook in load_state_dict for backward compatibility.

    Note:
        We saved self.pe until v.0.5.2 but we have omitted it later.
        Therefore, we remove the item "pe" from `state_dict` for backward compatibility.

    """
    k = prefix + "pe"
    if k in state_dict:
        state_dict.pop(k)


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
        reverse (bool): Whether to reverse the input position. Only for
            the class LegacyRelPositionalEncoding. We remove it in the current
            class RelPositionalEncoding.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
        """Construct a PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)

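A short sanity check for the sinusoidal table (editorial example, not part of the commit): feeding zeros isolates the encoding itself, since the scaled-input term vanishes:

pe = PositionalEncoding(d_model=256, dropout_rate=0.0)
x = torch.zeros(1, 5, 256)
y = pe(x)   # (1, 5, 256): pure sin/cos values, even dims sine, odd dims cosine
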
class ScaledPositionalEncoding(PositionalEncoding):
    """Scaled positional encoding module.

    See Sec. 3.2 https://arxiv.org/abs/1809.08895

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Reset parameters."""
        self.alpha.data = torch.tensor(1.0)

    def forward(self, x):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).

        """
        self.extend_pe(x)
        x = x + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(x)

class LearnableFourierPosEnc(torch.nn.Module):
    """Learnable Fourier Features for Positional Encoding.

    See https://arxiv.org/pdf/2106.02795.pdf

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
        gamma (float): init parameter for the positional kernel variance
            see https://arxiv.org/pdf/2106.02795.pdf.
        apply_scaling (bool): Whether to scale the input before adding the pos encoding.
        hidden_dim (int): if not None, we modulate the pos encodings with
            an MLP whose hidden layer has hidden_dim neurons.
    """

    def __init__(
        self,
        d_model,
        dropout_rate=0.0,
        max_len=5000,
        gamma=1.0,
        apply_scaling=False,
        hidden_dim=None,
    ):
        """Initialize class."""
        super(LearnableFourierPosEnc, self).__init__()

        self.d_model = d_model

        if apply_scaling:
            self.xscale = math.sqrt(self.d_model)
        else:
            self.xscale = 1.0

        self.dropout = torch.nn.Dropout(dropout_rate)
        self.max_len = max_len

        self.gamma = gamma
        if self.gamma is None:
            self.gamma = self.d_model // 2

        assert (
            d_model % 2 == 0
        ), "d_model should be divisible by two in order to use this layer."
        self.w_r = torch.nn.Parameter(torch.empty(1, d_model // 2))
        self._reset()  # init the weights

        self.hidden_dim = hidden_dim
        if self.hidden_dim is not None:
            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(d_model, hidden_dim),
                torch.nn.GELU(),
                torch.nn.Linear(hidden_dim, d_model),
            )

    def _reset(self):
        self.w_r.data = torch.normal(
            0, (1 / math.sqrt(self.gamma)), (1, self.d_model // 2)
        )

    def extend_pe(self, x):
        """Compute the positional encodings for the current input length."""
        position_v = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1).to(x)

        cosine = torch.cos(torch.matmul(position_v, self.w_r))
        sine = torch.sin(torch.matmul(position_v, self.w_r))
        pos_enc = torch.cat((cosine, sine), -1)
        pos_enc /= math.sqrt(self.d_model)

        if self.hidden_dim is None:
            return pos_enc.unsqueeze(0)
        else:
            return self.mlp(pos_enc.unsqueeze(0))

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        pe = self.extend_pe(x)
        x = x * self.xscale + pe
        return self.dropout(x)

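Usage is the same as the fixed encodings above, but here `w_r` (and the optional MLP) train with the model (editorial example):

pos = LearnableFourierPosEnc(d_model=256, dropout_rate=0.0, gamma=1.0, hidden_dim=128)
x = torch.randn(2, 100, 256)
y = pos(x)   # (2, 100, 256); gradients flow into pos.w_r and pos.mlp
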
class LegacyRelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module (old version).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(
            d_model=d_model,
            dropout_rate=dropout_rate,
            max_len=max_len,
            reverse=True,
        )

    def forward(self, x):
        """Compute positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
            torch.Tensor: Positional embedding tensor (1, time, `*`).

        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, : x.size(1)]
        return self.dropout(x), self.dropout(pos_emb)

class RelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Construct a RelPositionalEncoding object."""
        super(RelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` is the position of the query vector and `j` the position
        # of the key vector. We use positive relative positions when keys are
        # to the left (i>j) and negative relative positions otherwise (i<j).
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in https://arxiv.org/abs/1901.02860
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
            torch.Tensor: Positional embedding tensor (1, 2*time-1, `*`).

        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[
            :,
            self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
        ]
        return self.dropout(x), self.dropout(pos_emb)

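Unlike the absolute encodings, this module returns the (scaled) input and a separate relative-position table of length 2T-1, indexed from +T-1 down to -(T-1) (editorial example):

relpe = RelPositionalEncoding(d_model=256, dropout_rate=0.0)
x = torch.randn(2, 10, 256)
x_scaled, pos_emb = relpe(x)
# x_scaled: (2, 10, 256); pos_emb: (1, 19, 256)
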
class StreamPositionalEncoding(torch.nn.Module):
    """Streaming positional encoding.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Construct a StreamPositionalEncoding object."""
        super(StreamPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.tmp = torch.tensor(0.0).expand(1, max_len)
        self.extend_pe(self.tmp.size(1), self.tmp.device, self.tmp.dtype)
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, length, device, dtype):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= length:
                if self.pe.dtype != dtype or self.pe.device != device:
                    self.pe = self.pe.to(dtype=dtype, device=device)
                return
        pe = torch.zeros(length, self.d_model)
        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=device, dtype=dtype)

    def forward(self, x: torch.Tensor, start_idx: int = 0):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            start_idx (int): Absolute position of the first frame in `x`.

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).

        """
        self.extend_pe(x.size(1) + start_idx, x.device, x.dtype)
        x = x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
        return self.dropout(x)

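Because `start_idx` shifts the slice into the precomputed table, chunked streaming matches the full-utterance result exactly when dropout is disabled (editorial check):

pe = StreamPositionalEncoding(d_model=256, dropout_rate=0.0)
x = torch.randn(1, 20, 256)
full = pe(x)
chunked = torch.cat([pe(x[:, :10], start_idx=0), pe(x[:, 10:], start_idx=10)], dim=1)
assert torch.allclose(full, chunked)
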
class SinusoidalPositionEncoder(torch.nn.Module):
    """Sinusoidal position encoder without trainable parameters."""

    def __init__(self, d_model=80, dropout_rate=0.1):
        # No state is needed; keep nn.Module initialization intact.
        super().__init__()

    def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32):
        batch_size = positions.size(0)
        positions = positions.type(dtype)
        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype)) / (depth / 2 - 1)
        inv_timescales = torch.exp(torch.arange(depth / 2).type(dtype) * (-log_timescale_increment))
        inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        return encoding.type(dtype)

    def forward(self, x):
        batch_size, timesteps, input_dim = x.size()
        positions = torch.arange(1, timesteps + 1)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)

        return x + position_encoding

class StreamSinusoidalPositionEncoder(torch.nn.Module):
    """Streaming variant of SinusoidalPositionEncoder with a position cache."""

    def __init__(self, d_model=80, dropout_rate=0.1):
        # No state is needed; keep nn.Module initialization intact.
        super().__init__()

    def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32):
        batch_size = positions.size(0)
        positions = positions.type(dtype)
        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype)) / (depth / 2 - 1)
        inv_timescales = torch.exp(torch.arange(depth / 2).type(dtype) * (-log_timescale_increment))
        inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        return encoding.type(dtype)

    def forward(self, x, cache=None):
        batch_size, timesteps, input_dim = x.size()
        start_idx = 0
        if cache is not None:
            start_idx = cache["start_idx"]
            cache["start_idx"] += timesteps
        positions = torch.arange(1, timesteps + start_idx + 1)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding[:, start_idx: start_idx + timesteps]

class StreamingRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding.

    Args:
        size: Module size.
        max_len: Maximum input length.
        dropout_rate: Dropout rate.

    """

    def __init__(
        self, size: int, dropout_rate: float = 0.0, max_len: int = 5000
    ) -> None:
        """Construct a RelativePositionalEncoding object."""
        super().__init__()

        self.size = size

        self.pe = None
        self.dropout = torch.nn.Dropout(p=dropout_rate)

        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
        """Reset positional encoding.

        Args:
            x: Input sequences. (B, T, ?)
            left_context: Number of frames in left context.

        """
        time1 = x.size(1) + left_context

        if self.pe is not None:
            if self.pe.size(1) >= time1 * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(device=x.device, dtype=x.dtype)
                return

        pe_positive = torch.zeros(time1, self.size)
        pe_negative = torch.zeros(time1, self.size)

        position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.size, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.size)
        )

        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)

        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
        pe_negative = pe_negative[1:].unsqueeze(0)

        self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(
            dtype=x.dtype, device=x.device
        )

    def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Compute positional encoding.

        Args:
            x: Input sequences. (B, T, ?)
            left_context: Number of frames in left context.

        Returns:
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), ?)

        """
        self.extend_pe(x, left_context=left_context)

        time1 = x.size(1) + left_context

        pos_enc = self.pe[
            :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)
        ]
        pos_enc = self.dropout(pos_enc)

        return pos_enc
1
funasr_local/modules/frontends/__init__.py
Normal file
@@ -0,0 +1 @@
"""Initialize sub package."""
84
funasr_local/modules/frontends/beamformer.py
Normal file
@@ -0,0 +1,84 @@
import torch
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor


def get_power_spectral_density_matrix(
    xs: ComplexTensor, mask: torch.Tensor, normalization=True, eps: float = 1e-15
) -> ComplexTensor:
    """Return cross-channel power spectral density (PSD) matrix

    Args:
        xs (ComplexTensor): (..., F, C, T)
        mask (torch.Tensor): (..., F, C, T)
        normalization (bool):
        eps (float):
    Returns:
        psd (ComplexTensor): (..., F, C, C)

    """
    # outer product: (..., C_1, T) x (..., C_2, T) -> (..., T, C_1, C_2)
    psd_Y = FC.einsum("...ct,...et->...tce", [xs, xs.conj()])

    # Averaging mask along C: (..., C, T) -> (..., T)
    mask = mask.mean(dim=-2)

    # Normalized mask along T: (..., T)
    if normalization:
        # If assuming the tensor is padded with zero, the summation along
        # the time axis is the same regardless of the padding length.
        mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)

    # psd: (..., T, C, C)
    psd = psd_Y * mask[..., None, None]
    # (..., T, C, C) -> (..., C, C)
    psd = psd.sum(dim=-3)

    return psd


def get_mvdr_vector(
    psd_s: ComplexTensor,
    psd_n: ComplexTensor,
    reference_vector: torch.Tensor,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the MVDR (Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): (..., F, C, C)
        psd_n (ComplexTensor): (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        eps (float):
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    # Add eps
    C = psd_n.size(-1)
    eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
    shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
    eye = eye.view(*shape)
    psd_n += eps * eye

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum("...ec,...cd->...ed", [psd_n.inverse(), psd_s])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    return beamform_vector


def apply_beamforming_vector(
    beamform_vector: ComplexTensor, mix: ComplexTensor
) -> ComplexTensor:
    # (..., C) x (..., C, T) -> (..., T)
    es = FC.einsum("...c,...ct->...t", [beamform_vector.conj(), mix])
    return es
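A shape-level walkthrough of the three functions above (editorial example; it assumes torch_complex is installed and uses random masks in place of a trained estimator):

F_, C, T = 4, 2, 50
xs = ComplexTensor(torch.randn(F_, C, T), torch.randn(F_, C, T))
mask_s = torch.rand(F_, C, T)
psd_s = get_power_spectral_density_matrix(xs, mask_s)          # (F, C, C)
psd_n = get_power_spectral_density_matrix(xs, 1.0 - mask_s)    # (F, C, C)
u = torch.eye(C)[0]                                            # one-hot reference channel
ws = get_mvdr_vector(psd_s, psd_n, u)                          # (F, C)
enhanced = apply_beamforming_vector(ws, xs)                    # (F, T)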
172
funasr_local/modules/frontends/dnn_beamformer.py
Normal file
@@ -0,0 +1,172 @@
"""DNN beamformer module."""
from typing import Tuple

import torch
from torch.nn import functional as F

from funasr_local.modules.frontends.beamformer import apply_beamforming_vector
from funasr_local.modules.frontends.beamformer import get_mvdr_vector
from funasr_local.modules.frontends.beamformer import (
    get_power_spectral_density_matrix,  # noqa: H301
)
from funasr_local.modules.frontends.mask_estimator import MaskEstimator
from torch_complex.tensor import ComplexTensor


class DNN_Beamformer(torch.nn.Module):
    """DNN mask based Beamformer

    Citation:
        Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
        https://arxiv.org/abs/1703.04783

    """

    def __init__(
        self,
        bidim,
        btype="blstmp",
        blayers=3,
        bunits=300,
        bprojs=320,
        bnmask=2,
        dropout_rate=0.0,
        badim=320,
        ref_channel: int = -1,
        beamformer_type="mvdr",
    ):
        super().__init__()
        self.mask = MaskEstimator(
            btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
        )
        self.ref = AttentionReference(bidim, badim)
        self.ref_channel = ref_channel

        self.nmask = bnmask

        if beamformer_type != "mvdr":
            raise ValueError(
                "Not supporting beamformer_type={}".format(beamformer_type)
            )
        self.beamformer_type = beamformer_type

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor): (B, T, F)
            ilens (torch.Tensor): (B,)
            mask_speech (torch.Tensor): (B, T, C, F)

        """

        def apply_beamforming(data, ilens, psd_speech, psd_noise):
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech, ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(
                    *(data.size()[:-3] + (data.size(-2),)), device=data.device
                )
                u[..., self.ref_channel].fill_(1)

            ws = get_mvdr_vector(psd_speech, psd_noise, u)
            enhanced = apply_beamforming_vector(ws, data)

            return enhanced, ws

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: (B, F, C, T)
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks)

        if self.nmask == 2:  # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks

            psd_speech = get_power_spectral_density_matrix(data, mask_speech)
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
            mask_speech = mask_speech.transpose(-1, -3)
        else:  # multi-speaker case: (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]

            psd_speeches = [
                get_power_spectral_density_matrix(data, mask) for mask in mask_speech
            ]
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced = []
            ws = []
            for i in range(self.nmask - 1):
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                enh, w = apply_beamforming(
                    data, ilens, psd_speech, sum(psd_speeches) + psd_noise
                )
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                mask_speech[i] = mask_speech[i].transpose(-1, -3)

                enhanced.append(enh)
                ws.append(w)

        return enhanced, ilens, mask_speech


class AttentionReference(torch.nn.Module):
    def __init__(self, bidim, att_dim):
        super().__init__()
        self.mlp_psd = torch.nn.Linear(bidim, att_dim)
        self.gvec = torch.nn.Linear(att_dim, 1)

    def forward(
        self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """The forward function

        Args:
            psd_in (ComplexTensor): (B, F, C, C)
            ilens (torch.Tensor): (B,)
            scaling (float):
        Returns:
            u (torch.Tensor): (B, C)
            ilens (torch.Tensor): (B,)
        """
        B, _, C = psd_in.size()[:3]
        assert psd_in.size(2) == psd_in.size(3), psd_in.size()
        # psd_in: (B, F, C, C)
        psd = psd_in.masked_fill(
            torch.eye(C, dtype=torch.bool, device=psd_in.device), 0
        )
        # psd: (B, F, C, C) -> (B, C, F)
        psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)

        # Calculate amplitude
        psd_feat = (psd.real**2 + psd.imag**2) ** 0.5

        # (B, C, F) -> (B, C, F2)
        mlp_psd = self.mlp_psd(psd_feat)
        # (B, C, F2) -> (B, C, 1) -> (B, C)
        e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
        u = F.softmax(scaling * e, dim=-1)
        return u, ilens
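A smoke test for the module above (editorial sketch; it assumes the sibling funasr_local RNN modules and torch_complex are importable, and the hyperparameters below are arbitrary):

bf = DNN_Beamformer(bidim=257, blayers=1, bunits=64, bprojs=64, badim=64)
B, T, C, F_ = 2, 50, 4, 257
data = ComplexTensor(torch.randn(B, T, C, F_), torch.randn(B, T, C, F_))
ilens = torch.LongTensor([50, 40])
enhanced, ilens, mask = bf(data, ilens)   # enhanced: (B, T, F), mask: (B, T, C, F)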
93
funasr_local/modules/frontends/dnn_wpe.py
Normal file
@@ -0,0 +1,93 @@
from typing import Tuple

from pytorch_wpe import wpe_one_iteration
import torch
from torch_complex.tensor import ComplexTensor

from funasr_local.modules.frontends.mask_estimator import MaskEstimator
from funasr_local.modules.nets_utils import make_pad_mask


class DNN_WPE(torch.nn.Module):
    def __init__(
        self,
        wtype: str = "blstmp",
        widim: int = 257,
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        dropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask: bool = True,
        iterations: int = 1,
        normalization: bool = False,
    ):
        super().__init__()
        self.iterations = iterations
        self.taps = taps
        self.delay = delay

        self.normalization = normalization
        self.use_dnn_mask = use_dnn_mask

        self.inverse_power = True

        if self.use_dnn_mask:
            self.mask_est = MaskEstimator(
                wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
            )

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F)
            ilens: (B,)
        Returns:
            data: (B, T, C, F)
            ilens: (B,)
        """
        # (B, T, C, F) -> (B, F, C, T)
        enhanced = data = data.permute(0, 3, 2, 1)
        mask = None

        for i in range(self.iterations):
            # Calculate power: (..., C, T)
            power = enhanced.real**2 + enhanced.imag**2
            if i == 0 and self.use_dnn_mask:
                # mask: (B, F, C, T)
                (mask,), _ = self.mask_est(enhanced, ilens)
                if self.normalization:
                    # Normalize along T
                    mask = mask / mask.sum(dim=-1)[..., None]
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = power * mask

            # Averaging along the channel axis: (..., C, T) -> (..., T)
            power = power.mean(dim=-2)

            # enhanced: (..., C, T) -> (..., C, T)
            enhanced = wpe_one_iteration(
                data.contiguous(),
                power,
                taps=self.taps,
                delay=self.delay,
                inverse_power=self.inverse_power,
            )

            enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = enhanced.permute(0, 3, 2, 1)
        if mask is not None:
            mask = mask.transpose(-1, -3)
        return enhanced, ilens, mask
263
funasr_local/modules/frontends/feature_transform.py
Normal file
@@ -0,0 +1,263 @@
from typing import List
from typing import Tuple
from typing import Union

import librosa
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor

from funasr_local.modules.nets_utils import make_pad_mask


class FeatureTransform(torch.nn.Module):
    def __init__(
        self,
        # Mel options,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        # Normalization
        stats_file: str = None,
        apply_uttmvn: bool = True,
        uttmvn_norm_means: bool = True,
        uttmvn_norm_vars: bool = False,
    ):
        super().__init__()
        self.apply_uttmvn = apply_uttmvn

        self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
        self.stats_file = stats_file
        if stats_file is not None:
            self.global_mvn = GlobalMVN(stats_file)
        else:
            self.global_mvn = None

        if self.apply_uttmvn:
            self.uttmvn = UtteranceMVN(
                norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
            )
        else:
            self.uttmvn = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

        if x.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Use the first channel
                h = x[:, :, 0, :]
        else:
            h = x

        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        h = h.real**2 + h.imag**2

        h, _ = self.logmel(h, ilens)
        if self.stats_file is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.apply_uttmvn:
            h, _ = self.uttmvn(h, ilens)

        return h, ilens


class LogMel(torch.nn.Module):
    """Convert STFT to fbank feats

    The arguments are the same as librosa.filters.mel

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz)
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        norm: {None, 1, np.inf} [scalar]
            if 1, divide the triangular mel weights by the width of the mel band
            (area normalization). Otherwise, leave all the triangles aiming for
            a peak value of 1.0

    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()

        _mel_options = dict(
            sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
        )
        self.mel_options = _mel_options

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (D2, D1) -> (D1, D2)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self, feat: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)

        logmel_feat = (mel_feat + 1e-20).log()
        # Zero padding
        logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
        return logmel_feat, ilens


class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization

    Args:
        stats_file (str): npy file of a 1-dim array or a text file.
            The first (len(array) - 1) / 2 elements are treated as the
            sum of features, the following elements (excluding the last
            one) as the sum of squared features, and the last element as
            the number of samples.
        std_floor (float):
    """

    def __init__(
        self,
        stats_file: str,
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars

        self.stats_file = stats_file
        stats = np.load(stats_file)

        stats = stats.astype(float)
        assert (len(stats) - 1) % 2 == 0, stats.shape

        count = stats.flatten()[-1]
        mean = stats[: (len(stats) - 1) // 2] / count
        var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean
        std = np.maximum(np.sqrt(var), eps)

        self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
        self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))

    def extra_repr(self):
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # feat: (B, T, D)
        if self.norm_means:
            x += self.bias.type_as(x)
            x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)

        if self.norm_vars:
            x *= self.scale.type_as(x)
        return x, ilens


class UtteranceMVN(torch.nn.Module):
    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps

    def extra_repr(self):
        return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        return utterance_mvn(
            x, ilens, norm_means=self.norm_means, norm_vars=self.norm_vars, eps=self.eps
        )


def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply utterance mean and variance normalization

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,)
        norm_means:
        norm_vars:
        eps:

    """
    ilens_ = ilens.type_as(x)
    # mean: (B, D)
    mean = x.sum(dim=1) / ilens_[:, None]

    if norm_means:
        x -= mean[:, None, :]
        x_ = x
    else:
        x_ = x - mean[:, None, :]

    # Zero padding
    x_.masked_fill(make_pad_mask(ilens, x_, 1), 0.0)
    if norm_vars:
        var = x_.pow(2).sum(dim=1) / ilens_[:, None]
        var = torch.clamp(var, min=eps)
        x /= var.sqrt()[:, None, :]
        x_ = x
    return x_, ilens


def feature_transform_for(args, n_fft):
    return FeatureTransform(
        # Mel options,
        fs=args.fbank_fs,
        n_fft=n_fft,
        n_mels=args.n_mels,
        fmin=args.fbank_fmin,
        fmax=args.fbank_fmax,
        # Normalization
        stats_file=args.stats_file,
        apply_uttmvn=args.apply_uttmvn,
        uttmvn_norm_means=args.uttmvn_norm_means,
        uttmvn_norm_vars=args.uttmvn_norm_vars,
    )
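A minimal LogMel plus utterance-MVN run (editorial example; the input is a random power spectrum of size n_fft/2+1, and note that newer librosa versions may expect norm='slaney' instead of norm=1):

logmel = LogMel(fs=16000, n_fft=512, n_mels=80)
power_spec = torch.rand(2, 100, 512 // 2 + 1)   # (B, T, n_fft // 2 + 1)
ilens = torch.LongTensor([100, 75])
feat, _ = logmel(power_spec, ilens)             # (2, 100, 80) log-mel features
feat, _ = utterance_mvn(feat, ilens)            # per-utterance mean removal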
151
funasr_local/modules/frontends/frontend.py
Normal file
@@ -0,0 +1,151 @@
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy
import torch
import torch.nn as nn
from torch_complex.tensor import ComplexTensor

from funasr_local.modules.frontends.dnn_beamformer import DNN_Beamformer
from funasr_local.modules.frontends.dnn_wpe import DNN_WPE


class Frontend(nn.Module):
    def __init__(
        self,
        idim: int,
        # WPE options
        use_wpe: bool = False,
        wtype: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = False,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        bnmask: int = 2,
        badim: int = 320,
        ref_channel: int = -1,
        bdropout_rate=0.0,
    ):
        super().__init__()

        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe
        self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
        # use frontend for all the data,
        # e.g. in the case of multi-speaker speech separation
        self.use_frontend_for_all = bnmask > 2

        if self.use_wpe:
            if self.use_dnn_mask_for_wpe:
                # Use DNN for power estimation
                # (Not observed significant gains)
                iterations = 1
            else:
                # Performing as conventional WPE, without DNN Estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wtype,
                widim=idim,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )
        else:
            self.wpe = None

        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=btype,
                bidim=idim,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                bnmask=bnmask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
            )
        else:
            self.beamformer = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
    ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
        assert len(x) == len(ilens), (len(x), len(ilens))
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)

        mask = None
        h = x
        if h.dim() == 4:
            if self.training:
                choices = [(False, False)] if not self.use_frontend_for_all else []
                if self.use_wpe:
                    choices.append((True, False))

                if self.use_beamformer:
                    choices.append((False, True))

                use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]

            else:
                use_wpe = self.use_wpe
                use_beamformer = self.use_beamformer

            # 1. WPE
            if use_wpe:
                # h: (B, T, C, F) -> h: (B, T, C, F)
                h, ilens, mask = self.wpe(h, ilens)

            # 2. Beamformer
            if use_beamformer:
                # h: (B, T, C, F) -> h: (B, T, F)
                h, ilens, mask = self.beamformer(h, ilens)

        return h, ilens, mask


def frontend_for(args, idim):
    return Frontend(
        idim=idim,
        # WPE options
        use_wpe=args.use_wpe,
        wtype=args.wtype,
        wlayers=args.wlayers,
        wunits=args.wunits,
        wprojs=args.wprojs,
        wdropout_rate=args.wdropout_rate,
        taps=args.wpe_taps,
        delay=args.wpe_delay,
        use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe,
        # Beamformer options
        use_beamformer=args.use_beamformer,
        btype=args.btype,
        blayers=args.blayers,
        bunits=args.bunits,
        bprojs=args.bprojs,
        bnmask=args.bnmask,
        badim=args.badim,
        ref_channel=args.ref_channel,
        bdropout_rate=args.bdropout_rate,
    )
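`frontend_for` only reads attribute names off `args`, so any namespace carrying these fields works (editorial sketch; the values below mirror the constructor defaults):

from argparse import Namespace
args = Namespace(
    use_wpe=False, wtype="blstmp", wlayers=3, wunits=300, wprojs=320,
    wdropout_rate=0.0, wpe_taps=5, wpe_delay=3, use_dnn_mask_for_wpe=True,
    use_beamformer=True, btype="blstmp", blayers=3, bunits=300, bprojs=320,
    bnmask=2, badim=320, ref_channel=-1, bdropout_rate=0.0,
)
frontend = frontend_for(args, idim=257)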
77
funasr_local/modules/frontends/mask_estimator.py
Normal file
@@ -0,0 +1,77 @@
from typing import Tuple

import numpy as np
import torch
from torch.nn import functional as F
from torch_complex.tensor import ComplexTensor

from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.rnn.encoders import RNN
from funasr_local.modules.rnn.encoders import RNNP


class MaskEstimator(torch.nn.Module):
    def __init__(self, type, idim, layers, units, projs, dropout, nmask=1):
        super().__init__()
        # np.int was removed from NumPy; plain int keeps the same behavior
        subsample = np.ones(layers + 1, dtype=int)

        typ = type.lstrip("vgg").rstrip("p")
        if type[-1] == "p":
            self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ)
        else:
            self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ)

        self.type = type
        self.nmask = nmask
        self.linears = torch.nn.ModuleList(
            [torch.nn.Linear(projs, idim) for _ in range(nmask)]
        )

    def forward(
        self, xs: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
        """The forward function

        Args:
            xs: (B, F, C, T)
            ilens: (B,)
        Returns:
            hs (torch.Tensor): The hidden vector (B, F, C, T)
            masks: A tuple of the masks. (B, F, C, T)
            ilens: (B,)
        """
        assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
        _, _, C, input_length = xs.size()
        # (B, F, C, T) -> (B, C, T, F)
        xs = xs.permute(0, 2, 3, 1)

        # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
        xs = (xs.real**2 + xs.imag**2) ** 0.5
        # xs: (B, C, T, F) -> xs: (B * C, T, F)
        xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
        # ilens: (B,) -> ilens_: (B * C)
        ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)

        # xs: (B * C, T, F) -> xs: (B * C, T, D)
        xs, _, _ = self.brnn(xs, ilens_)
        # xs: (B * C, T, D) -> xs: (B, C, T, D)
        xs = xs.view(-1, C, xs.size(-2), xs.size(-1))

        masks = []
        for linear in self.linears:
            # xs: (B, C, T, D) -> mask:(B, C, T, F)
            mask = linear(xs)

            mask = torch.sigmoid(mask)
            # Zero padding
            mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)

            # (B, C, T, F) -> (B, F, C, T)
            mask = mask.permute(0, 3, 1, 2)

            # Take care of multi-GPU cases: if input_length > max(ilens)
            if mask.size(-1) < input_length:
                mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
            masks.append(mask)

        return tuple(masks), ilens
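A shape check for the estimator (editorial sketch; it assumes the repo's RNNP encoder is importable and torch_complex is installed):

est = MaskEstimator("blstmp", 257, 2, 128, 128, 0.0, nmask=2)
xs = ComplexTensor(torch.randn(2, 257, 4, 50), torch.randn(2, 257, 4, 50))  # (B, F, C, T)
ilens = torch.LongTensor([50, 40])
(mask_speech, mask_noise), _ = est(xs, ilens)   # each mask: (B, F, C, T), values in (0, 1)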
42
funasr_local/modules/layer_norm.py
Normal file
@@ -0,0 +1,42 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Layer normalization module."""

import torch


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    Args:
        nout (int): Output dim size.
        dim (int): Dimension to be normalized.

    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Normalized tensor.

        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return (
            super(LayerNorm, self)
            .forward(x.transpose(self.dim, -1))
            .transpose(self.dim, -1)
        )
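The `dim` argument exists so a non-last axis can be normalized by transposing it to the end and back; a quick sketch (sizes are illustrative):

import torch

norm = LayerNorm(nout=256, dim=1)  # normalize the channel axis of (B, C, T)
x = torch.randn(4, 256, 50)
y = norm(x)                        # same shape, normalized over dim=1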
112
funasr_local/modules/lightconv.py
Normal file
@@ -0,0 +1,112 @@
"""Lightweight Convolution Module."""

import numpy
import torch
from torch import nn
import torch.nn.functional as F


MIN_VALUE = float(numpy.finfo(numpy.float32).min)


class LightweightConvolution(nn.Module):
    """Lightweight Convolution layer.

    This implementation is based on
    https://github.com/pytorch/fairseq/tree/master/fairseq

    Args:
        wshare (int): the number of shared kernels (heads) of the convolution
        n_feat (int): the number of features
        dropout_rate (float): dropout rate
        kernel_size (int): kernel size (length)
        use_kernel_mask (bool): Use causal mask or not for convolution kernel
        use_bias (bool): Use bias term or not.

    """

    def __init__(
        self,
        wshare,
        n_feat,
        dropout_rate,
        kernel_size,
        use_kernel_mask=False,
        use_bias=False,
    ):
        """Construct Lightweight Convolution layer."""
        super(LightweightConvolution, self).__init__()

        assert n_feat % wshare == 0
        self.wshare = wshare
        self.use_kernel_mask = use_kernel_mask
        self.dropout_rate = dropout_rate
        self.kernel_size = kernel_size
        self.padding_size = int(kernel_size / 2)

        # linear -> GLU -> lightconv -> linear
        self.linear1 = nn.Linear(n_feat, n_feat * 2)
        self.linear2 = nn.Linear(n_feat, n_feat)
        self.act = nn.GLU()

        # lightconv related
        self.weight = nn.Parameter(
            torch.Tensor(self.wshare, 1, kernel_size).uniform_(0, 1)
        )
        self.use_bias = use_bias
        if self.use_bias:
            self.bias = nn.Parameter(torch.Tensor(n_feat))

        # mask of kernel
        kernel_mask0 = torch.zeros(self.wshare, int(kernel_size / 2))
        kernel_mask1 = torch.ones(self.wshare, int(kernel_size / 2 + 1))
        self.kernel_mask = torch.cat((kernel_mask1, kernel_mask0), dim=-1).unsqueeze(1)

    def forward(self, query, key, value, mask):
        """Forward of 'Lightweight Convolution'.

        This function takes query, key and value but uses only query.
        This is just for compatibility with the self-attention layer (attention.py).

        Args:
            query (torch.Tensor): (batch, time1, d_model) input tensor
            key (torch.Tensor): (batch, time2, d_model) NOT USED
            value (torch.Tensor): (batch, time2, d_model) NOT USED
            mask (torch.Tensor): (batch, time1, time2) mask

        Return:
            x (torch.Tensor): (batch, time1, d_model) output

        """
        # linear -> GLU -> lightconv -> linear
        x = query
        B, T, C = x.size()
        H = self.wshare

        # first linear layer
        x = self.linear1(x)

        # GLU activation
        x = self.act(x)

        # lightconv
        x = x.transpose(1, 2).contiguous().view(-1, H, T)  # B x C x T
        weight = F.dropout(self.weight, self.dropout_rate, training=self.training)
        if self.use_kernel_mask:
            self.kernel_mask = self.kernel_mask.to(x.device)
            weight = weight.masked_fill(self.kernel_mask == 0.0, float("-inf"))
        weight = F.softmax(weight, dim=-1)
        x = F.conv1d(x, weight, padding=self.padding_size, groups=self.wshare).view(
            B, C, T
        )
        if self.use_bias:
            x = x + self.bias.view(1, -1, 1)
        x = x.transpose(1, 2)  # B x T x C

        if mask is not None and not self.use_kernel_mask:
            mask = mask.transpose(-1, -2)
            x = x.masked_fill(mask == 0, 0.0)

        # second linear layer
        x = self.linear2(x)
        return x
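Because `forward` mirrors the (query, key, value, mask) signature of the attention layer, the module can be dropped into an encoder in place of self-attention; a small sketch with illustrative sizes:

import torch

conv = LightweightConvolution(wshare=4, n_feat=256, dropout_rate=0.1,
                              kernel_size=31, use_kernel_mask=True)
x = torch.randn(2, 50, 256)  # (batch, time, d_model)
y = conv(x, x, x, None)      # key/value are ignored; output keeps the input shape
assert y.shape == x.shape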
124
funasr_local/modules/lightconv2d.py
Normal file
@@ -0,0 +1,124 @@
"""Lightweight 2-Dimensional Convolution module."""

import numpy
import torch
from torch import nn
import torch.nn.functional as F


MIN_VALUE = float(numpy.finfo(numpy.float32).min)


class LightweightConvolution2D(nn.Module):
    """Lightweight 2-Dimensional Convolution layer.

    This implementation is based on
    https://github.com/pytorch/fairseq/tree/master/fairseq

    Args:
        wshare (int): the number of shared kernels (heads) of the convolution
        n_feat (int): the number of features
        dropout_rate (float): dropout rate
        kernel_size (int): kernel size (length)
        use_kernel_mask (bool): Use causal mask or not for convolution kernel
        use_bias (bool): Use bias term or not.

    """

    def __init__(
        self,
        wshare,
        n_feat,
        dropout_rate,
        kernel_size,
        use_kernel_mask=False,
        use_bias=False,
    ):
        """Construct Lightweight 2-Dimensional Convolution layer."""
        super(LightweightConvolution2D, self).__init__()

        assert n_feat % wshare == 0
        self.wshare = wshare
        self.use_kernel_mask = use_kernel_mask
        self.dropout_rate = dropout_rate
        self.kernel_size = kernel_size
        self.padding_size = int(kernel_size / 2)

        # linear -> GLU -> lightconv -> linear
        self.linear1 = nn.Linear(n_feat, n_feat * 2)
        self.linear2 = nn.Linear(n_feat * 2, n_feat)
        self.act = nn.GLU()

        # lightconv related
        self.weight = nn.Parameter(
            torch.Tensor(self.wshare, 1, kernel_size).uniform_(0, 1)
        )
        self.weight_f = nn.Parameter(torch.Tensor(1, 1, kernel_size).uniform_(0, 1))
        self.use_bias = use_bias
        if self.use_bias:
            self.bias = nn.Parameter(torch.Tensor(n_feat))

        # mask of kernel
        kernel_mask0 = torch.zeros(self.wshare, int(kernel_size / 2))
        kernel_mask1 = torch.ones(self.wshare, int(kernel_size / 2 + 1))
        self.kernel_mask = torch.cat((kernel_mask1, kernel_mask0), dim=-1).unsqueeze(1)

    def forward(self, query, key, value, mask):
        """Forward of 'Lightweight 2-Dimensional Convolution'.

        This function takes query, key and value but uses only query.
        This is just for compatibility with the self-attention layer (attention.py).

        Args:
            query (torch.Tensor): (batch, time1, d_model) input tensor
            key (torch.Tensor): (batch, time2, d_model) NOT USED
            value (torch.Tensor): (batch, time2, d_model) NOT USED
            mask (torch.Tensor): (batch, time1, time2) mask

        Return:
            x (torch.Tensor): (batch, time1, d_model) output

        """
        # linear -> GLU -> lightconv -> linear
        x = query
        B, T, C = x.size()
        H = self.wshare

        # first linear layer
        x = self.linear1(x)

        # GLU activation
        x = self.act(x)

        # convolution along frequency axis
        weight_f = F.softmax(self.weight_f, dim=-1)
        weight_f = F.dropout(weight_f, self.dropout_rate, training=self.training)
        weight_new = torch.zeros(
            B * T, 1, self.kernel_size, device=x.device, dtype=x.dtype
        ).copy_(weight_f)
        xf = F.conv1d(
            x.view(1, B * T, C), weight_new, padding=self.padding_size, groups=B * T
        ).view(B, T, C)

        # lightconv
        x = x.transpose(1, 2).contiguous().view(-1, H, T)  # B x C x T
        weight = F.dropout(self.weight, self.dropout_rate, training=self.training)
        if self.use_kernel_mask:
            self.kernel_mask = self.kernel_mask.to(x.device)
            weight = weight.masked_fill(self.kernel_mask == 0.0, float("-inf"))
        weight = F.softmax(weight, dim=-1)
        x = F.conv1d(x, weight, padding=self.padding_size, groups=self.wshare).view(
            B, C, T
        )
        if self.use_bias:
            x = x + self.bias.view(1, -1, 1)
        x = x.transpose(1, 2)  # B x T x C
        x = torch.cat((x, xf), -1)  # B x T x (2 * C)

        if mask is not None and not self.use_kernel_mask:
            mask = mask.transpose(-1, -2)
            x = x.masked_fill(mask == 0, 0.0)

        # second linear layer
        x = self.linear2(x)
        return x
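The 2D variant runs a second, weight-shared convolution along the frequency axis and concatenates it with the time-axis branch, which is why `linear2` maps 2 * n_feat back to n_feat; a shape check with illustrative sizes:

import torch

conv2d = LightweightConvolution2D(wshare=4, n_feat=256, dropout_rate=0.1,
                                  kernel_size=31)
x = torch.randn(2, 50, 256)
y = conv2d(x, x, x, None)    # both branches concatenated, then projected back
assert y.shape == x.shape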
52
funasr_local/modules/mask.py
Normal file
@@ -0,0 +1,52 @@
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Mask module."""

import torch


def subsequent_mask(size, device="cpu", dtype=torch.bool):
    """Create mask for subsequent steps (size, size).

    :param int size: size of mask
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :rtype: torch.Tensor
    >>> subsequent_mask(3)
    [[1, 0, 0],
     [1, 1, 0],
     [1, 1, 1]]
    """
    ret = torch.ones(size, size, device=device, dtype=dtype)
    return torch.tril(ret, out=ret)


def target_mask(ys_in_pad, ignore_id):
    """Create mask for decoder self-attention.

    :param torch.Tensor ys_in_pad: batch of padded target sequences (B, Lmax)
    :param int ignore_id: index of padding
    :rtype: torch.Tensor (B, Lmax, Lmax)
    """
    ys_mask = ys_in_pad != ignore_id
    m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
    return ys_mask.unsqueeze(-2) & m


def vad_mask(size, vad_pos, device="cpu", dtype=torch.bool):
    """Create mask for decoder self-attention with a VAD position.

    :param int size: size of mask
    :param int vad_pos: position of the VAD token
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :rtype: torch.Tensor (B, Lmax, Lmax)
    """
    ret = torch.ones(size, size, device=device, dtype=dtype)
    if vad_pos <= 0 or vad_pos >= size:
        return ret
    sub_corner = torch.zeros(
        vad_pos - 1, size - vad_pos, device=device, dtype=dtype)
    ret[0:vad_pos - 1, vad_pos:] = sub_corner
    return ret
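A quick sanity check of the two masks (sizes are illustrative):

import torch

print(subsequent_mask(3).int())
# tensor([[1, 0, 0],
#         [1, 1, 0],
#         [1, 1, 1]], dtype=torch.int32)

# With vad_pos=2, positions before the VAD token cannot attend past it:
print(vad_mask(4, 2).int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]], dtype=torch.int32)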
157
funasr_local/modules/multi_layer_conv.py
Normal file
@@ -0,0 +1,157 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Layer modules for FFT block in FastSpeech (Feed-forward Transformer)."""

import torch


class MultiLayeredConv1d(torch.nn.Module):
    """Multi-layered conv1d for Transformer block.

    This is a module of multi-layered conv1d designed
    to replace the positionwise feed-forward network
    in a Transformer block, which is introduced in
    `FastSpeech: Fast, Robust and Controllable Text to Speech`_.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf

    """

    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
        """Initialize MultiLayeredConv1d module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            kernel_size (int): Kernel size of conv1d.
            dropout_rate (float): Dropout rate.

        """
        super(MultiLayeredConv1d, self).__init__()
        self.w_1 = torch.nn.Conv1d(
            in_chans,
            hidden_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.w_2 = torch.nn.Conv1d(
            hidden_chans,
            in_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (torch.Tensor): Batch of input tensors (B, T, in_chans).

        Returns:
            torch.Tensor: Batch of output tensors (B, T, in_chans).

        """
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)


class FsmnFeedForward(torch.nn.Module):
    """Position-wise feed forward for FSMN blocks.

    This is a module of multi-layered conv1d designed
    to replace the position-wise feed-forward network
    in FSMN blocks.
    """

    def __init__(self, in_chans, hidden_chans, out_chans, kernel_size, dropout_rate):
        """Initialize FsmnFeedForward module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            out_chans (int): Number of output channels.
            kernel_size (int): Kernel size of conv1d.
            dropout_rate (float): Dropout rate.

        """
        super(FsmnFeedForward, self).__init__()
        self.w_1 = torch.nn.Conv1d(
            in_chans,
            hidden_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.w_2 = torch.nn.Conv1d(
            hidden_chans,
            out_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            bias=False
        )
        self.norm = torch.nn.LayerNorm(hidden_chans)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x, ilens=None):
        """Calculate forward propagation.

        Args:
            x (torch.Tensor): Batch of input tensors (B, T, in_chans).

        Returns:
            torch.Tensor: Batch of output tensors (B, T, out_chans).

        """
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.norm(self.dropout(x)).transpose(-1, 1)).transpose(-1, 1), ilens


class Conv1dLinear(torch.nn.Module):
    """Conv1D + Linear for Transformer block.

    A variant of MultiLayeredConv1d, which replaces the second conv layer
    with a linear layer.

    """

    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
        """Initialize Conv1dLinear module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            kernel_size (int): Kernel size of conv1d.
            dropout_rate (float): Dropout rate.

        """
        super(Conv1dLinear, self).__init__()
        self.w_1 = torch.nn.Conv1d(
            in_chans,
            hidden_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (torch.Tensor): Batch of input tensors (B, T, in_chans).

        Returns:
            torch.Tensor: Batch of output tensors (B, T, in_chans).

        """
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.dropout(x))
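A usage sketch for the conv-based FFN replacement (sizes are illustrative):

import torch

ffn = MultiLayeredConv1d(in_chans=256, hidden_chans=1024,
                         kernel_size=3, dropout_rate=0.1)
x = torch.randn(2, 50, 256)
assert ffn(x).shape == x.shape  # position-wise FFN with local (kernel=3) context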
701
funasr_local/modules/nets_utils.py
Normal file
@@ -0,0 +1,701 @@
# -*- coding: utf-8 -*-

"""Network related utility tools."""

import logging
from typing import Dict, List, Tuple

import numpy as np
import torch


def to_device(m, x):
    """Send tensor into the device of the module.

    Args:
        m (torch.nn.Module): Torch module.
        x (Tensor): Torch tensor.

    Returns:
        Tensor: Torch tensor located in the same place as torch module.

    """
    if isinstance(m, torch.nn.Module):
        device = next(m.parameters()).device
    elif isinstance(m, torch.Tensor):
        device = m.device
    else:
        raise TypeError(
            "Expected torch.nn.Module or torch.Tensor, " f"but got: {type(m)}"
        )
    return x.to(device)


def pad_list(xs, pad_value):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)

    for i in range(n_batch):
        pad[i, : xs[i].size(0)] = xs[i]

    return pad


def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        Tensor: Mask tensor containing indices of padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_pad_mask(lengths, xs, 1)
        tensor([[[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
        >>> make_pad_mask(lengths, xs, 2)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if maxlen is None:
        if xs is None:
            maxlen = int(max(lengths))
        else:
            maxlen = xs.size(length_dim)
    else:
        assert xs is None
        assert maxlen >= int(max(lengths))

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, None, ..., None)
        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask


def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        ByteTensor: mask tensor containing indices of non-padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

    """
    return ~make_pad_mask(lengths, xs, length_dim)


def mask_by_length(xs, lengths, fill=0):
    """Mask tensor according to length.

    Args:
        xs (Tensor): Batch of input tensor (B, `*`).
        lengths (LongTensor or List): Batch of lengths (B,).
        fill (int or float): Value to fill masked part.

    Returns:
        Tensor: Batch of masked input tensor (B, `*`).

    Examples:
        >>> x = torch.arange(5).repeat(3, 1) + 1
        >>> x
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 4, 5],
                [1, 2, 3, 4, 5]])
        >>> lengths = [5, 3, 2]
        >>> mask_by_length(x, lengths)
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 0, 0],
                [1, 2, 0, 0, 0]])

    """
    assert xs.size(0) == len(lengths)
    ret = xs.data.new(*xs.size()).fill_(fill)
    for i, l in enumerate(lengths):
        ret[i, :l] = xs[i, :l]
    return ret


def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).

    """
    pad_pred = pad_outputs.view(
        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
    ).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)


def to_torch_tensor(x):
    """Change to torch.Tensor or ComplexTensor from numpy.ndarray.

    Args:
        x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.

    Returns:
        Tensor or ComplexTensor: Type converted inputs.

    Examples:
        >>> xs = np.ones(3, dtype=np.float32)
        >>> xs = to_torch_tensor(xs)
        tensor([1., 1., 1.])
        >>> xs = torch.ones(3, 4, 5)
        >>> assert to_torch_tensor(xs) is xs
        >>> xs = {'real': xs, 'imag': xs}
        >>> to_torch_tensor(xs)
        ComplexTensor(
        Real:
        tensor([1., 1., 1.])
        Imag:
        tensor([1., 1., 1.])
        )

    """
    # If numpy, change to torch tensor
    if isinstance(x, np.ndarray):
        if x.dtype.kind == "c":
            # Dynamically importing because torch_complex requires python3
            from torch_complex.tensor import ComplexTensor

            return ComplexTensor(x)
        else:
            return torch.from_numpy(x)

    # If {'real': ..., 'imag': ...}, convert to ComplexTensor
    elif isinstance(x, dict):
        # Dynamically importing because torch_complex requires python3
        from torch_complex.tensor import ComplexTensor

        if "real" not in x or "imag" not in x:
            raise ValueError("must have 'real' and 'imag' keys: {}".format(list(x)))
        # Relative importing because of using python3 syntax
        return ComplexTensor(x["real"], x["imag"])

    # If torch.Tensor, as it is
    elif isinstance(x, torch.Tensor):
        return x

    else:
        error = (
            "x must be numpy.ndarray, torch.Tensor or a dict like "
            "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
            "but got {}".format(type(x))
        )
        try:
            from torch_complex.tensor import ComplexTensor
        except Exception:
            # If PY2
            raise ValueError(error)
        else:
            # If PY3
            if isinstance(x, ComplexTensor):
                return x
            else:
                raise ValueError(error)


def get_subsample(train_args, mode, arch):
    """Parse the subsampling factors from the args for the specified `mode` and `arch`.

    Args:
        train_args: argument Namespace containing options.
        mode: one of ('asr', 'mt', 'st')
        arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')

    Returns:
        np.ndarray / List[np.ndarray]: subsampling factors.
    """
    if arch == "transformer":
        return np.array([1])

    elif mode == "mt" and arch == "rnn":
        # +1 means input (+1) and layers outputs (train_args.elayers)
        # np.int was removed in NumPy 1.24; use an explicit integer dtype
        subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
        logging.warning("Subsampling is not performed for machine translation.")
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample

    elif (
        (mode == "asr" and arch in ("rnn", "rnn-t"))
        or (mode == "mt" and arch == "rnn")
        or (mode == "st" and arch == "rnn")
    ):
        subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                "Subsampling is not performed for vgg*. "
                "It is performed in max pooling layers at CNN."
            )
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample

    elif mode == "asr" and arch == "rnn_mix":
        subsample = np.ones(
            train_args.elayers_sd + train_args.elayers + 1, dtype=np.int64
        )
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(
                min(train_args.elayers_sd + train_args.elayers + 1, len(ss))
            ):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                "Subsampling is not performed for vgg*. "
                "It is performed in max pooling layers at CNN."
            )
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample

    elif mode == "asr" and arch == "rnn_mulenc":
        subsample_list = []
        for idx in range(train_args.num_encs):
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int64)
            if train_args.etype[idx].endswith("p") and not train_args.etype[
                idx
            ].startswith("vgg"):
                ss = train_args.subsample[idx].split("_")
                for j in range(min(train_args.elayers[idx] + 1, len(ss))):
                    subsample[j] = int(ss[j])
            else:
                logging.warning(
                    "Encoder %d: Subsampling is not performed for vgg*. "
                    "It is performed in max pooling layers at CNN.",
                    idx + 1,
                )
            logging.info("subsample: " + " ".join([str(x) for x in subsample]))
            subsample_list.append(subsample)
        return subsample_list

    else:
        raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))


def rename_state_dict(
    old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
):
    """Replace keys of old prefix with new prefix in state dict."""
    # need this list not to break the dict iterator
    old_keys = [k for k in state_dict if k.startswith(old_prefix)]
    if len(old_keys) > 0:
        logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
    for k in old_keys:
        v = state_dict.pop(k)
        new_k = k.replace(old_prefix, new_prefix)
        state_dict[new_k] = v


class Swish(torch.nn.Module):
    """Construct a Swish object."""

    def forward(self, x):
        """Return Swish activation function."""
        return x * torch.sigmoid(x)


def get_activation(act):
    """Return activation function."""

    activation_funcs = {
        "hardtanh": torch.nn.Hardtanh,
        "tanh": torch.nn.Tanh,
        "relu": torch.nn.ReLU,
        "selu": torch.nn.SELU,
        "swish": Swish,
    }

    return activation_funcs[act]()


class TooShortUttError(Exception):
    """Raised when the utt is too short for subsampling.

    Args:
        message: Error message to display.
        actual_size: The size that cannot pass the subsampling.
        limit: The size limit for subsampling.

    """

    def __init__(self, message: str, actual_size: int, limit: int) -> None:
        """Construct a TooShortUttError module."""
        super().__init__(message)

        self.actual_size = actual_size
        self.limit = limit


def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
    """Check if the input is too short for subsampling.

    Args:
        sub_factor: Subsampling factor for Conv2DSubsampling.
        size: Input size.

    Returns:
        : Whether an error should be sent.
        : Size limit for specified subsampling factor.

    """
    if sub_factor == 2 and size < 3:
        return True, 7
    elif sub_factor == 4 and size < 7:
        return True, 7
    elif sub_factor == 6 and size < 11:
        return True, 11

    return False, -1


def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
    """Get conv2D second layer parameters for given subsampling factor.

    Args:
        sub_factor: Subsampling factor (1/X).
        input_size: Input size.

    Returns:
        : Kernel size for second convolution.
        : Stride for second convolution.
        : Conv2DSubsampling output size.

    """
    if sub_factor == 2:
        return 3, 1, (((input_size - 1) // 2 - 2))
    elif sub_factor == 4:
        return 3, 2, (((input_size - 1) // 2 - 1) // 2)
    elif sub_factor == 6:
        return 5, 3, (((input_size - 1) // 2 - 2) // 3)
    else:
        raise ValueError(
            "subsampling_factor parameter should be set to either 2, 4 or 6."
        )

def make_chunk_mask(
    size: int,
    chunk_size: int,
    left_chunk_size: int = 0,
    device: torch.device = None,
) -> torch.Tensor:
    """Create chunk mask for the subsequent steps (size, size).

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        size: Size of the source mask.
        chunk_size: Number of frames in chunk.
        left_chunk_size: Size of the left context in chunks (a negative value
            means full left context).
        device: Device for the mask tensor.

    Returns:
        mask: Chunk mask. (size, size)

    """
    mask = torch.zeros(size, size, device=device, dtype=torch.bool)

    for i in range(size):
        if left_chunk_size < 0:
            start = 0
        else:
            start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)

        end = min((i // chunk_size + 1) * chunk_size, size)
        mask[i, start:end] = True

    return ~mask

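A small illustration of the chunk mask (sizes are illustrative); note that `True` in the returned tensor marks positions that are masked out:

import torch

# 6 frames, chunks of 2, one chunk of left context: frame i sees its own chunk
# plus at most one chunk to its left.
m = make_chunk_mask(6, 2, left_chunk_size=1)
print((~m).int())  # invert to display the *visible* positions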
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Create source mask for given lengths.

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        lengths: Sequence lengths. (B,)

    Returns:
        : Mask for the sequence lengths. (B, max_len)

    """
    max_len = lengths.max()
    batch_size = lengths.size(0)

    expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)

    return expanded_lengths >= lengths.unsqueeze(1)


def get_transducer_task_io(
    labels: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ignore_id: int = -1,
    blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get Transducer loss I/O.

    Args:
        labels: Label ID sequences. (B, L)
        encoder_out_lens: Encoder output lengths. (B,)
        ignore_id: Padding symbol ID.
        blank_id: Blank symbol ID.

    Returns:
        decoder_in: Decoder inputs. (B, U)
        target: Target label ID sequences. (B, U)
        t_len: Time lengths. (B,)
        u_len: Label lengths. (B,)

    """

    def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
        """Create padded batch of labels from a list of labels sequences.

        Args:
            labels: Labels sequences. [B x (?)]
            padding_value: Padding value.

        Returns:
            labels: Batch of padded labels sequences. (B,)

        """
        batch_size = len(labels)

        padded = (
            labels[0]
            .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
            .fill_(padding_value)
        )

        for i in range(batch_size):
            padded[i, : labels[i].size(0)] = labels[i]

        return padded

    device = labels.device

    labels_unpad = [y[y != ignore_id] for y in labels]
    blank = labels[0].new([blank_id])

    decoder_in = pad_list(
        [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
    ).to(device)

    target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)

    encoder_out_lens = list(map(int, encoder_out_lens))
    t_len = torch.IntTensor(encoder_out_lens).to(device)

    u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)

    return decoder_in, target, t_len, u_len


def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
    if t.size(dim) == pad_len:
        return t
    else:
        pad_size = list(t.shape)
        pad_size[dim] = pad_len - t.size(dim)
        return torch.cat(
            [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
        )
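A worked example of the Transducer I/O preparation (label values are made up; ignore_id=-1 and blank_id=0 are the defaults):

import torch

labels = torch.tensor([[3, 5, 7], [4, 6, -1]])  # -1 pads the shorter sequence
enc_lens = torch.tensor([10, 8])
dec_in, target, t_len, u_len = get_transducer_task_io(labels, enc_lens)
# dec_in: [[0, 3, 5, 7], [0, 4, 6, 0]]  (blank prepended, blank-padded)
# target: [[3, 5, 7], [4, 6, 0]], t_len: [10, 8], u_len: [3, 2]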
58
funasr_local/modules/positionwise_feed_forward.py
Normal file
@@ -0,0 +1,58 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Positionwise feed forward layer definition."""

import torch

from funasr_local.modules.layer_norm import LayerNorm


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation

    def forward(self, x):
        """Forward function."""
        return self.w_2(self.dropout(self.activation(self.w_1(x))))


class PositionwiseFeedForwardDecoderSANM(torch.nn.Module):
    """Positionwise feed forward layer for the SANM decoder.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        adim (int, optional): Output dimension (defaults to idim).

    """

    def __init__(self, idim, hidden_units, dropout_rate, adim=None, activation=torch.nn.ReLU()):
        """Construct a PositionwiseFeedForwardDecoderSANM object."""
        super(PositionwiseFeedForwardDecoderSANM, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim if adim is None else adim, bias=False)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation
        self.norm = LayerNorm(hidden_units)

    def forward(self, x):
        """Forward function."""
        return self.w_2(self.norm(self.dropout(self.activation(self.w_1(x)))))
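A shape sketch for the positionwise layer (sizes are illustrative):

import torch

ffn = PositionwiseFeedForward(idim=256, hidden_units=2048, dropout_rate=0.1)
x = torch.randn(2, 50, 256)
assert ffn(x).shape == x.shape  # w_2(dropout(relu(w_1(x)))), applied per frame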
124
funasr_local/modules/repeat.py
Normal file
@@ -0,0 +1,124 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Repeat the same layer definition."""

from typing import List, Optional

import torch


class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential."""

    def forward(self, *args):
        """Repeat."""
        for m in self:
            args = m(*args)
        return args


def repeat(N, fn):
    """Repeat module N times.

    Args:
        N (int): Number of times to repeat.
        fn (Callable): Function to generate module.

    Returns:
        MultiSequential: Repeated model instance.

    """
    return MultiSequential(*[fn(n) for n in range(N)])


class MultiBlocks(torch.nn.Module):
    """MultiBlocks definition.

    Args:
        block_list: Individual blocks of the encoder architecture.
        output_size: Architecture output size.
        norm_class: Normalization module class.
    """

    def __init__(
        self,
        block_list: List[torch.nn.Module],
        output_size: int,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
    ) -> None:
        """Construct a MultiBlocks object."""
        super().__init__()

        self.blocks = torch.nn.ModuleList(block_list)
        self.norm_blocks = norm_class(output_size)

        self.num_blocks = len(block_list)

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset encoder streaming cache.

        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.
        """
        for idx in range(self.num_blocks):
            self.blocks[idx].reset_streaming_cache(left_context, device)

    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward each block of the encoder architecture.

        Args:
            x: MultiBlocks input sequences. (B, T, D_block_1)
            pos_enc: Positional embedding sequences.
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)

        Returns:
            x: Output sequences. (B, T, D_block_N)
        """
        for block_index, block in enumerate(self.blocks):
            x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask)

        x = self.norm_blocks(x)

        return x

    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 0,
        left_context: int = 0,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Forward each block of the encoder architecture.

        Args:
            x: MultiBlocks input sequences. (B, T, D_block_1)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
            mask: Source mask. (B, T_2)
            chunk_size: Number of frames per chunk.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: MultiBlocks output sequences. (B, T, D_block_N)
        """
        for block_idx, block in enumerate(self.blocks):
            x, pos_enc = block.chunk_forward(
                x,
                pos_enc,
                mask,
                chunk_size=chunk_size,
                left_context=left_context,
                right_context=right_context,
            )

        x = self.norm_blocks(x)

        return x
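`repeat` expects each generated block to take and return the full argument tuple; a toy sketch (the Block class below is a made-up stand-in for an encoder layer):

import torch

class Block(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.lin = torch.nn.Linear(dim, dim)

    def forward(self, x, mask):
        return self.lin(x), mask  # return everything the next block needs

stack = repeat(6, lambda layer_id: Block(256))
x, mask = stack(torch.randn(2, 50, 256), None)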
1
funasr_local/modules/rnn/__init__.py
Normal file
@@ -0,0 +1 @@
"""Initialize sub package."""
156
funasr_local/modules/rnn/argument.py
Normal file
@@ -0,0 +1,156 @@
# Copyright 2020 Hirofumi Inaguma
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""RNN common arguments."""


def add_arguments_rnn_encoder_common(group):
    """Define common arguments for RNN encoder."""
    group.add_argument(
        "--etype",
        default="blstmp",
        type=str,
        choices=[
            "lstm",
            "blstm",
            "lstmp",
            "blstmp",
            "vgglstmp",
            "vggblstmp",
            "vgglstm",
            "vggblstm",
            "gru",
            "bgru",
            "grup",
            "bgrup",
            "vgggrup",
            "vggbgrup",
            "vgggru",
            "vggbgru",
        ],
        help="Type of encoder network architecture",
    )
    group.add_argument(
        "--elayers",
        default=4,
        type=int,
        help="Number of encoder layers",
    )
    group.add_argument(
        "--eunits",
        "-u",
        default=300,
        type=int,
        help="Number of encoder hidden units",
    )
    group.add_argument(
        "--eprojs", default=320, type=int, help="Number of encoder projection units"
    )
    group.add_argument(
        "--subsample",
        default="1",
        type=str,
        help="Subsample input frames x_y_z means "
        "subsample every x frame at 1st layer, "
        "every y frame at 2nd layer etc.",
    )
    return group


def add_arguments_rnn_decoder_common(group):
    """Define common arguments for RNN decoder."""
    group.add_argument(
        "--dtype",
        default="lstm",
        type=str,
        choices=["lstm", "gru"],
        help="Type of decoder network architecture",
    )
    group.add_argument(
        "--dlayers", default=1, type=int, help="Number of decoder layers"
    )
    group.add_argument(
        "--dunits", default=320, type=int, help="Number of decoder hidden units"
    )
    group.add_argument(
        "--dropout-rate-decoder",
        default=0.0,
        type=float,
        help="Dropout rate for the decoder",
    )
    group.add_argument(
        "--sampling-probability",
        default=0.0,
        type=float,
        help="Ratio of predicted labels fed back to decoder",
    )
    group.add_argument(
        "--lsm-type",
        const="",
        default="",
        type=str,
        nargs="?",
        choices=["", "unigram"],
        help="Apply label smoothing with a specified distribution type",
    )
    return group


def add_arguments_rnn_attention_common(group):
    """Define common arguments for RNN attention."""
    group.add_argument(
        "--atype",
        default="dot",
        type=str,
        choices=[
            "noatt",
            "dot",
            "add",
            "location",
            "coverage",
            "coverage_location",
            "location2d",
            "location_recurrent",
            "multi_head_dot",
            "multi_head_add",
            "multi_head_loc",
            "multi_head_multi_res_loc",
        ],
        help="Type of attention architecture",
    )
    group.add_argument(
        "--adim",
        default=320,
        type=int,
        help="Number of attention transformation dimensions",
    )
    group.add_argument(
        "--awin", default=5, type=int, help="Window size for location2d attention"
    )
    group.add_argument(
        "--aheads",
        default=4,
        type=int,
        help="Number of heads for multi head attention",
    )
    group.add_argument(
        "--aconv-chans",
        default=-1,
        type=int,
        help="Number of attention convolution channels \
        (negative value indicates no location-aware attention)",
    )
    group.add_argument(
        "--aconv-filts",
        default=100,
        type=int,
        help="Number of attention convolution filters \
        (negative value indicates no location-aware attention)",
    )
    group.add_argument(
        "--dropout-rate",
        default=0.0,
        type=float,
        help="Dropout rate for the encoder",
    )
    return group
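A sketch of how these helpers plug into an argparse parser (flag values are illustrative):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group("rnn encoder")
add_arguments_rnn_encoder_common(group)
args = parser.parse_args(["--etype", "vggblstmp", "--elayers", "3"])
print(args.etype, args.elayers, args.subsample)  # vggblstmp 3 1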
1808
funasr_local/modules/rnn/attentions.py
Normal file
File diff suppressed because it is too large
1211
funasr_local/modules/rnn/decoders.py
Normal file
File diff suppressed because it is too large
372
funasr_local/modules/rnn/encoders.py
Normal file
@@ -0,0 +1,372 @@
import logging

import numpy as np
import six
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

from funasr_local.modules.e2e_asr_common import get_vgg2l_odim
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.nets_utils import to_device


class RNNP(torch.nn.Module):
    """RNN with projection layer module.

    :param int idim: dimension of inputs
    :param int elayers: number of encoder layers
    :param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional)
    :param int hdim: number of projection units
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param str typ: The RNN type
    """

    def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"):
        super(RNNP, self).__init__()
        bidir = typ[0] == "b"
        for i in six.moves.range(elayers):
            if i == 0:
                inputdim = idim
            else:
                inputdim = hdim

            RNN = torch.nn.LSTM if "lstm" in typ else torch.nn.GRU
            rnn = RNN(
                inputdim, cdim, num_layers=1, bidirectional=bidir, batch_first=True
            )

            setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)

            # bottleneck layer to merge
            if bidir:
                setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim))
            else:
                setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim))

        self.elayers = elayers
        self.cdim = cdim
        self.subsample = subsample
        self.typ = typ
        self.bidir = bidir
        self.dropout = dropout

    def forward(self, xs_pad, ilens, prev_state=None):
        """RNNP forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, hdim)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))
        elayer_states = []
        for layer in six.moves.range(self.elayers):
            if not isinstance(ilens, torch.Tensor):
                ilens = torch.tensor(ilens)
            xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)
            rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
            rnn.flatten_parameters()
            if prev_state is not None and rnn.bidirectional:
                prev_state = reset_backward_rnn_state(prev_state)
            ys, states = rnn(
                xs_pack, hx=None if prev_state is None else prev_state[layer]
            )
            elayer_states.append(states)
            # ys: utt list of frame x cdim x 2 (2: means bidirectional)
            ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
            sub = self.subsample[layer + 1]
            if sub > 1:
                ys_pad = ys_pad[:, ::sub]
                ilens = torch.tensor([int(i + 1) // sub for i in ilens])
            # (sum _utt frame_utt) x dim
            projection_layer = getattr(self, "bt%d" % layer)
            projected = projection_layer(ys_pad.contiguous().view(-1, ys_pad.size(2)))
            xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
            if layer < self.elayers - 1:
                xs_pad = torch.tanh(F.dropout(xs_pad, p=self.dropout))

        return xs_pad, ilens, elayer_states  # x: utt list of frame x dim

class RNN(torch.nn.Module):
    """RNN module.

    :param int idim: dimension of inputs
    :param int elayers: number of encoder layers
    :param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional)
    :param int hdim: number of final projection units
    :param float dropout: dropout rate
    :param str typ: The RNN type
    """

    def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"):
        super(RNN, self).__init__()
        bidir = typ[0] == "b"
        self.nbrnn = (
            torch.nn.LSTM(
                idim,
                cdim,
                elayers,
                batch_first=True,
                dropout=dropout,
                bidirectional=bidir,
            )
            if "lstm" in typ
            else torch.nn.GRU(
                idim,
                cdim,
                elayers,
                batch_first=True,
                dropout=dropout,
                bidirectional=bidir,
            )
        )
        if bidir:
            self.l_last = torch.nn.Linear(cdim * 2, hdim)
        else:
            self.l_last = torch.nn.Linear(cdim, hdim)
        self.typ = typ

    def forward(self, xs_pad, ilens, prev_state=None):
        """RNN forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))
        if not isinstance(ilens, torch.Tensor):
            ilens = torch.tensor(ilens)
        xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)
        self.nbrnn.flatten_parameters()
        if prev_state is not None and self.nbrnn.bidirectional:
            # We assume that when previous state is passed,
            # it means that we're streaming the input
            # and therefore cannot propagate backward BRNN state
            # (otherwise it goes in the wrong direction)
            prev_state = reset_backward_rnn_state(prev_state)
        ys, states = self.nbrnn(xs_pack, hx=prev_state)
        # ys: utt list of frame x cdim x 2 (2: means bidirectional)
        ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
        # (sum _utt frame_utt) x dim
        projected = torch.tanh(
            self.l_last(ys_pad.contiguous().view(-1, ys_pad.size(2)))
        )
        xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
        return xs_pad, ilens, states  # x: utt list of frame x dim

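A minimal sketch of running the projected encoder (sizes are illustrative; lengths must be sorted in descending order for pack_padded_sequence):

import torch

enc = RNN(idim=80, elayers=2, cdim=320, hdim=256, dropout=0.0, typ="blstm")
xs = torch.randn(2, 100, 80)  # (B, Tmax, idim)
hs, hlens, state = enc(xs, torch.tensor([100, 60]))
assert hs.shape == (2, 100, 256)  # projected to hdim by l_last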
def reset_backward_rnn_state(states):
|
||||
"""Sets backward BRNN states to zeroes
|
||||
|
||||
Useful in processing of sliding windows over the inputs
|
||||
"""
|
||||
if isinstance(states, (list, tuple)):
|
||||
for state in states:
|
||||
state[1::2] = 0.0
|
||||
else:
|
||||
states[1::2] = 0.0
|
||||
return states
|
||||
|
||||
|
||||
class VGG2L(torch.nn.Module):
|
||||
"""VGG-like module
|
||||
|
||||
:param int in_channel: number of input channels
|
||||
"""
|
||||
|
||||
def __init__(self, in_channel=1):
|
||||
super(VGG2L, self).__init__()
|
||||
# CNN layer (VGG motivated)
|
||||
self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1)
|
||||
self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1)
|
||||
self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
|
||||
self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1)
|
||||
|
||||
self.in_channel = in_channel
|
||||
|
||||
def forward(self, xs_pad, ilens, **kwargs):
|
||||
"""VGG2L forward
|
||||
|
||||
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
|
||||
:param torch.Tensor ilens: batch of lengths of input sequences (B)
|
||||
:return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4)
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))
|
||||
|
||||
# x: utt x frame x dim
|
||||
# xs_pad = F.pad_sequence(xs_pad)
|
||||
|
||||
# x: utt x 1 (input channel num) x frame x dim
|
||||
xs_pad = xs_pad.view(
|
||||
xs_pad.size(0),
|
||||
xs_pad.size(1),
|
||||
self.in_channel,
|
||||
xs_pad.size(2) // self.in_channel,
|
||||
).transpose(1, 2)
|
||||
|
||||
# NOTE: max_pool1d ?
|
||||
xs_pad = F.relu(self.conv1_1(xs_pad))
|
||||
xs_pad = F.relu(self.conv1_2(xs_pad))
|
||||
xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True)
|
||||
|
||||
xs_pad = F.relu(self.conv2_1(xs_pad))
|
||||
xs_pad = F.relu(self.conv2_2(xs_pad))
|
||||
xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True)
|
||||
if torch.is_tensor(ilens):
|
||||
ilens = ilens.cpu().numpy()
|
||||
else:
|
||||
ilens = np.array(ilens, dtype=np.float32)
|
||||
ilens = np.array(np.ceil(ilens / 2), dtype=np.int64)
|
||||
ilens = np.array(
|
||||
np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64
|
||||
).tolist()
|
||||
|
||||
# x: utt_list of frame (remove zeropaded frames) x (input channel num x dim)
|
||||
xs_pad = xs_pad.transpose(1, 2)
|
||||
xs_pad = xs_pad.contiguous().view(
|
||||
xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3)
|
||||
)
|
||||
return xs_pad, ilens, None # no state in this layer
|
||||
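# Shape sketch (illustrative, not part of the original source): the two
# ceil-mode stride-2 max-pools turn a (B, Tmax, D) input into roughly
# (B, ceil(ceil(Tmax / 2) / 2), 128 * ceil(ceil(D / 2) / 2)).
#
#   vgg = VGG2L(in_channel=1)
#   xs = torch.randn(2, 100, 83)
#   ys, olens, _ = vgg(xs, torch.tensor([100, 60]))
#   assert ys.shape == (2, 25, 128 * 21)   # ceil(83/2)=42, ceil(42/2)=21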
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
"""Encoder module
|
||||
|
||||
:param str etype: type of encoder network
|
||||
:param int idim: number of dimensions of encoder network
|
||||
:param int elayers: number of layers of encoder network
|
||||
:param int eunits: number of lstm units of encoder network
|
||||
:param int eprojs: number of projection units of encoder network
|
||||
:param np.ndarray subsample: list of subsampling numbers
|
||||
:param float dropout: dropout rate
|
||||
:param int in_channel: number of input channels
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1
|
||||
):
|
||||
super(Encoder, self).__init__()
|
||||
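        # NOTE: lstrip()/rstrip() strip character *sets*, not prefixes,
        # so an etype such as "gru" would be mangled to "ru" here;
        # only the vgg-prefixed and *p-suffixed names survive intact.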
typ = etype.lstrip("vgg").rstrip("p")
|
||||
if typ not in ["lstm", "gru", "blstm", "bgru"]:
|
||||
logging.error("Error: need to specify an appropriate encoder architecture")
|
||||
|
||||
if etype.startswith("vgg"):
|
||||
if etype[-1] == "p":
|
||||
self.enc = torch.nn.ModuleList(
|
||||
[
|
||||
VGG2L(in_channel),
|
||||
RNNP(
|
||||
get_vgg2l_odim(idim, in_channel=in_channel),
|
||||
elayers,
|
||||
eunits,
|
||||
eprojs,
|
||||
subsample,
|
||||
dropout,
|
||||
typ=typ,
|
||||
),
|
||||
]
|
||||
)
|
||||
logging.info("Use CNN-VGG + " + typ.upper() + "P for encoder")
|
||||
else:
|
||||
self.enc = torch.nn.ModuleList(
|
||||
[
|
||||
VGG2L(in_channel),
|
||||
RNN(
|
||||
get_vgg2l_odim(idim, in_channel=in_channel),
|
||||
elayers,
|
||||
eunits,
|
||||
eprojs,
|
||||
dropout,
|
||||
typ=typ,
|
||||
),
|
||||
]
|
||||
)
|
||||
logging.info("Use CNN-VGG + " + typ.upper() + " for encoder")
|
||||
self.conv_subsampling_factor = 4
|
||||
else:
|
||||
if etype[-1] == "p":
|
||||
self.enc = torch.nn.ModuleList(
|
||||
[RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ)]
|
||||
)
|
||||
logging.info(typ.upper() + " with every-layer projection for encoder")
|
||||
else:
|
||||
self.enc = torch.nn.ModuleList(
|
||||
[RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)]
|
||||
)
|
||||
logging.info(typ.upper() + " without projection for encoder")
|
||||
self.conv_subsampling_factor = 1
|
||||
|
||||
def forward(self, xs_pad, ilens, prev_states=None):
|
||||
"""Encoder forward
|
||||
|
||||
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
|
||||
:param torch.Tensor ilens: batch of lengths of input sequences (B)
|
||||
:param torch.Tensor prev_state: batch of previous encoder hidden states (?, ...)
|
||||
:return: batch of hidden state sequences (B, Tmax, eprojs)
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
if prev_states is None:
|
||||
prev_states = [None] * len(self.enc)
|
||||
assert len(prev_states) == len(self.enc)
|
||||
|
||||
current_states = []
|
||||
for module, prev_state in zip(self.enc, prev_states):
|
||||
xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
|
||||
current_states.append(states)
|
||||
|
||||
# make mask to remove bias value in padded part
|
||||
mask = to_device(xs_pad, make_pad_mask(ilens).unsqueeze(-1))
|
||||
|
||||
return xs_pad.masked_fill(mask, 0.0), ilens, current_states
|
||||
|
||||
|
||||
def encoder_for(args, idim, subsample):
|
||||
"""Instantiates an encoder module given the program arguments
|
||||
|
||||
:param Namespace args: The arguments
|
||||
:param int or List of integer idim: dimension of input, e.g. 83, or
|
||||
List of dimensions of inputs, e.g. [83,83]
|
||||
:param List or List of List subsample: subsample factors, e.g. [1,2,2,1,1], or
|
||||
List of subsample factors of each encoder.
|
||||
e.g. [[1,2,2,1,1], [1,2,2,1,1]]
|
||||
:rtype torch.nn.Module
|
||||
:return: The encoder module
|
||||
"""
|
||||
num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility
|
||||
if num_encs == 1:
|
||||
# compatible with single encoder asr mode
|
||||
return Encoder(
|
||||
args.etype,
|
||||
idim,
|
||||
args.elayers,
|
||||
args.eunits,
|
||||
args.eprojs,
|
||||
subsample,
|
||||
args.dropout_rate,
|
||||
)
|
||||
    elif num_encs >= 2:
|
||||
enc_list = torch.nn.ModuleList()
|
||||
for idx in range(num_encs):
|
||||
enc = Encoder(
|
||||
args.etype[idx],
|
||||
idim[idx],
|
||||
args.elayers[idx],
|
||||
args.eunits[idx],
|
||||
args.eprojs,
|
||||
subsample[idx],
|
||||
args.dropout_rate[idx],
|
||||
)
|
||||
enc_list.append(enc)
|
||||
return enc_list
|
||||
else:
|
||||
raise ValueError(
|
||||
"Number of encoders needs to be more than one. {}".format(num_encs)
|
||||
)
|
||||
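# Hedged usage sketch (argument names mirror the espnet-style training
# namespace; the values are illustrative only):
#
#   from argparse import Namespace
#   args = Namespace(etype="vggblstmp", elayers=4, eunits=320,
#                    eprojs=320, dropout_rate=0.0, num_encs=1)
#   enc = encoder_for(args, idim=83, subsample=[1, 2, 2, 1, 1])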
1
funasr_local/modules/scorers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Initialize sub package."""
|
||||
158
funasr_local/modules/scorers/ctc.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""ScorerInterface implementation for CTC."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from funasr_local.modules.scorers.ctc_prefix_score import CTCPrefixScore
|
||||
from funasr_local.modules.scorers.ctc_prefix_score import CTCPrefixScoreTH
|
||||
from funasr_local.modules.scorers.scorer_interface import BatchPartialScorerInterface
|
||||
|
||||
|
||||
class CTCPrefixScorer(BatchPartialScorerInterface):
|
||||
"""Decoder interface wrapper for CTCPrefixScore."""
|
||||
|
||||
def __init__(self, ctc: torch.nn.Module, eos: int):
|
||||
"""Initialize class.
|
||||
|
||||
Args:
|
||||
ctc (torch.nn.Module): The CTC implementation.
|
||||
For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
|
||||
eos (int): The end-of-sequence id.
|
||||
|
||||
"""
|
||||
self.ctc = ctc
|
||||
self.eos = eos
|
||||
self.impl = None
|
||||
|
||||
def init_state(self, x: torch.Tensor):
|
||||
"""Get an initial state for decoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The encoded feature tensor
|
||||
|
||||
Returns: initial state
|
||||
|
||||
"""
|
||||
logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy()
|
||||
# TODO(karita): use CTCPrefixScoreTH
|
||||
self.impl = CTCPrefixScore(logp, 0, self.eos, np)
|
||||
return 0, self.impl.initial_state()
|
||||
|
||||
def select_state(self, state, i, new_id=None):
|
||||
"""Select state with relative ids in the main beam search.
|
||||
|
||||
Args:
|
||||
state: Decoder state for prefix tokens
|
||||
i (int): Index to select a state in the main beam search
|
||||
new_id (int): New label id to select a state if necessary
|
||||
|
||||
Returns:
|
||||
state: pruned state
|
||||
|
||||
"""
|
||||
        if isinstance(state, tuple):
|
||||
if len(state) == 2: # for CTCPrefixScore
|
||||
sc, st = state
|
||||
return sc[i], st[i]
|
||||
else: # for CTCPrefixScoreTH (need new_id > 0)
|
||||
r, log_psi, f_min, f_max, scoring_idmap = state
|
||||
s = log_psi[i, new_id].expand(log_psi.size(1))
|
||||
if scoring_idmap is not None:
|
||||
return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
|
||||
else:
|
||||
return r[:, :, i, new_id], s, f_min, f_max
|
||||
return None if state is None else state[i]
|
||||
|
||||
def score_partial(self, y, ids, state, x):
|
||||
"""Score new token.
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): 1D prefix token
|
||||
            ids (torch.Tensor): torch.int64 next tokens to score
|
||||
state: decoder state for prefix tokens
|
||||
x (torch.Tensor): 2D encoder feature that generates ys
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, Any]:
|
||||
                Tuple of a score tensor for y that has a shape `(len(ids),)`
|
||||
and next state for ys
|
||||
|
||||
"""
|
||||
prev_score, state = state
|
||||
presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
|
||||
tscore = torch.as_tensor(
|
||||
presub_score - prev_score, device=x.device, dtype=x.dtype
|
||||
)
|
||||
return tscore, (presub_score, new_st)
|
||||
|
||||
def batch_init_state(self, x: torch.Tensor):
|
||||
"""Get an initial state for decoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The encoded feature tensor
|
||||
|
||||
Returns: initial state
|
||||
|
||||
"""
|
||||
logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1
|
||||
xlen = torch.tensor([logp.size(1)])
|
||||
self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos)
|
||||
return None
|
||||
|
||||
def batch_score_partial(self, y, ids, state, x):
|
||||
"""Score new token.
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): 1D prefix token
|
||||
ids (torch.Tensor): torch.int64 next token to score
|
||||
state: decoder state for prefix tokens
|
||||
x (torch.Tensor): 2D encoder feature that generates ys
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, Any]:
|
||||
                Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
|
||||
and next state for ys
|
||||
|
||||
"""
|
||||
batch_state = (
|
||||
(
|
||||
torch.stack([s[0] for s in state], dim=2),
|
||||
torch.stack([s[1] for s in state]),
|
||||
state[0][2],
|
||||
state[0][3],
|
||||
)
|
||||
if state[0] is not None
|
||||
else None
|
||||
)
|
||||
return self.impl(y, batch_state, ids)
|
||||
|
||||
def extend_prob(self, x: torch.Tensor):
|
||||
"""Extend probs for decoding.
|
||||
|
||||
This extension is for streaming decoding
|
||||
as in Eq (14) in https://arxiv.org/abs/2006.14941
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The encoded feature tensor
|
||||
|
||||
"""
|
||||
logp = self.ctc.log_softmax(x.unsqueeze(0))
|
||||
self.impl.extend_prob(logp)
|
||||
|
||||
def extend_state(self, state):
|
||||
"""Extend state for decoding.
|
||||
|
||||
This extension is for streaming decoding
|
||||
as in Eq (14) in https://arxiv.org/abs/2006.14941
|
||||
|
||||
Args:
|
||||
state: The states of hyps
|
||||
|
||||
        Returns: extended state
|
||||
|
||||
"""
|
||||
new_state = []
|
||||
for s in state:
|
||||
new_state.append(self.impl.extend_state(s))
|
||||
|
||||
return new_state
|
||||
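# Hedged wiring sketch: CTCPrefixScorer is normally registered as a
# partial scorer alongside the attention decoder in beam search (the
# names below follow the espnet-style BeamSearch API and are
# illustrative, not defined in this file):
#
#   scorers = {"decoder": decoder, "ctc": CTCPrefixScorer(ctc, eos=eos_id)}
#   weights = {"decoder": 0.7, "ctc": 0.3}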
359
funasr_local/modules/scorers/ctc_prefix_score.py
Normal file
@@ -0,0 +1,359 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
|
||||
class CTCPrefixScoreTH(object):
|
||||
"""Batch processing of CTCPrefixScore
|
||||
|
||||
which is based on Algorithm 2 in WATANABE et al.
|
||||
"HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
|
||||
    but extended to efficiently compute the label probabilities for multiple
|
||||
    hypotheses simultaneously.
|
||||
See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
|
||||
Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
|
||||
"""
|
||||
|
||||
def __init__(self, x, xlens, blank, eos, margin=0):
|
||||
"""Construct CTC prefix scorer
|
||||
|
||||
:param torch.Tensor x: input label posterior sequences (B, T, O)
|
||||
:param torch.Tensor xlens: input lengths (B,)
|
||||
:param int blank: blank label id
|
||||
:param int eos: end-of-sequence id
|
||||
:param int margin: margin parameter for windowing (0 means no windowing)
|
||||
"""
|
||||
# In the comment lines,
|
||||
# we assume T: input_length, B: batch size, W: beam width, O: output dim.
|
||||
self.logzero = -10000000000.0
|
||||
self.blank = blank
|
||||
self.eos = eos
|
||||
self.batch = x.size(0)
|
||||
self.input_length = x.size(1)
|
||||
self.odim = x.size(2)
|
||||
self.dtype = x.dtype
|
||||
self.device = (
|
||||
torch.device("cuda:%d" % x.get_device())
|
||||
if x.is_cuda
|
||||
else torch.device("cpu")
|
||||
)
|
||||
# Pad the rest of posteriors in the batch
|
||||
# TODO(takaaki-hori): need a better way without for-loops
|
||||
for i, l in enumerate(xlens):
|
||||
if l < self.input_length:
|
||||
x[i, l:, :] = self.logzero
|
||||
x[i, l:, blank] = 0
|
||||
# Reshape input x
|
||||
xn = x.transpose(0, 1) # (B, T, O) -> (T, B, O)
|
||||
xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
|
||||
self.x = torch.stack([xn, xb]) # (2, T, B, O)
|
||||
self.end_frames = torch.as_tensor(xlens) - 1
|
||||
|
||||
# Setup CTC windowing
|
||||
self.margin = margin
|
||||
if margin > 0:
|
||||
self.frame_ids = torch.arange(
|
||||
self.input_length, dtype=self.dtype, device=self.device
|
||||
)
|
||||
# Base indices for index conversion
|
||||
self.idx_bh = None
|
||||
self.idx_b = torch.arange(self.batch, device=self.device)
|
||||
self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)
|
||||
|
||||
def __call__(self, y, state, scoring_ids=None, att_w=None):
|
||||
"""Compute CTC prefix scores for next labels
|
||||
|
||||
:param list y: prefix label sequences
|
||||
:param tuple state: previous CTC state
|
||||
        :param torch.Tensor scoring_ids: pre-selected token ids to score (BW, O)
|
||||
        :param torch.Tensor att_w: attention weights to decide CTC window
|
||||
        :return ctc_local_scores (BW, O), new_state
|
||||
"""
|
||||
output_length = len(y[0]) - 1 # ignore sos
|
||||
last_ids = [yi[-1] for yi in y] # last output label ids
|
||||
n_bh = len(last_ids) # batch * hyps
|
||||
n_hyps = n_bh // self.batch # assuming each utterance has the same # of hyps
|
||||
self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
|
||||
# prepare state info
|
||||
if state is None:
|
||||
r_prev = torch.full(
|
||||
(self.input_length, 2, self.batch, n_hyps),
|
||||
self.logzero,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
|
||||
r_prev = r_prev.view(-1, 2, n_bh)
|
||||
s_prev = 0.0
|
||||
f_min_prev = 0
|
||||
f_max_prev = 1
|
||||
else:
|
||||
r_prev, s_prev, f_min_prev, f_max_prev = state
|
||||
|
||||
# select input dimensions for scoring
|
||||
if self.scoring_num > 0:
|
||||
scoring_idmap = torch.full(
|
||||
(n_bh, self.odim), -1, dtype=torch.long, device=self.device
|
||||
)
|
||||
snum = self.scoring_num
|
||||
if self.idx_bh is None or n_bh > len(self.idx_bh):
|
||||
self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
|
||||
scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(
|
||||
snum, device=self.device
|
||||
)
|
||||
scoring_idx = (
|
||||
scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)
|
||||
).view(-1)
|
||||
x_ = torch.index_select(
|
||||
self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx
|
||||
).view(2, -1, n_bh, snum)
|
||||
else:
|
||||
scoring_ids = None
|
||||
scoring_idmap = None
|
||||
snum = self.odim
|
||||
x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)
|
||||
|
||||
# new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
|
||||
# that corresponds to r_t^n(h) and r_t^b(h) in a batch.
|
||||
r = torch.full(
|
||||
(self.input_length, 2, n_bh, snum),
|
||||
self.logzero,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
if output_length == 0:
|
||||
r[0, 0] = x_[0, 0]
|
||||
|
||||
r_sum = torch.logsumexp(r_prev, 1)
|
||||
log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
|
||||
if scoring_ids is not None:
|
||||
for idx in range(n_bh):
|
||||
pos = scoring_idmap[idx, last_ids[idx]]
|
||||
if pos >= 0:
|
||||
log_phi[:, idx, pos] = r_prev[:, 1, idx]
|
||||
else:
|
||||
for idx in range(n_bh):
|
||||
log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]
|
||||
|
||||
# decide start and end frames based on attention weights
|
||||
if att_w is not None and self.margin > 0:
|
||||
f_arg = torch.matmul(att_w, self.frame_ids)
|
||||
f_min = max(int(f_arg.min().cpu()), f_min_prev)
|
||||
f_max = max(int(f_arg.max().cpu()), f_max_prev)
|
||||
start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
|
||||
end = min(f_max + self.margin, self.input_length)
|
||||
else:
|
||||
f_min = f_max = 0
|
||||
start = max(output_length, 1)
|
||||
end = self.input_length
|
||||
|
||||
# compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
|
||||
for t in range(start, end):
|
||||
rp = r[t - 1]
|
||||
rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
|
||||
2, 2, n_bh, snum
|
||||
)
|
||||
r[t] = torch.logsumexp(rr, 1) + x_[:, t]
|
||||
|
||||
# compute log prefix probabilities log(psi)
|
||||
log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
|
||||
if scoring_ids is not None:
|
||||
log_psi = torch.full(
|
||||
(n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device
|
||||
)
|
||||
log_psi_ = torch.logsumexp(
|
||||
torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
|
||||
dim=0,
|
||||
)
|
||||
for si in range(n_bh):
|
||||
log_psi[si, scoring_ids[si]] = log_psi_[si]
|
||||
else:
|
||||
log_psi = torch.logsumexp(
|
||||
torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
for si in range(n_bh):
|
||||
log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]
|
||||
|
||||
# exclude blank probs
|
||||
log_psi[:, self.blank] = self.logzero
|
||||
|
||||
return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap)
|
||||
|
||||
def index_select_state(self, state, best_ids):
|
||||
"""Select CTC states according to best ids
|
||||
|
||||
:param state : CTC state
|
||||
:param best_ids : index numbers selected by beam pruning (B, W)
|
||||
:return selected_state
|
||||
"""
|
||||
r, s, f_min, f_max, scoring_idmap = state
|
||||
# convert ids to BHO space
|
||||
n_bh = len(s)
|
||||
n_hyps = n_bh // self.batch
|
||||
vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
|
||||
# select hypothesis scores
|
||||
s_new = torch.index_select(s.view(-1), 0, vidx)
|
||||
s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
|
||||
# convert ids to BHS space (S: scoring_num)
|
||||
if scoring_idmap is not None:
|
||||
snum = self.scoring_num
|
||||
hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
|
||||
-1
|
||||
)
|
||||
label_ids = torch.fmod(best_ids, self.odim).view(-1)
|
||||
score_idx = scoring_idmap[hyp_idx, label_ids]
|
||||
score_idx[score_idx == -1] = 0
|
||||
vidx = score_idx + hyp_idx * snum
|
||||
else:
|
||||
snum = self.odim
|
||||
# select forward probabilities
|
||||
r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
|
||||
-1, 2, n_bh
|
||||
)
|
||||
return r_new, s_new, f_min, f_max
|
||||
|
||||
def extend_prob(self, x):
|
||||
"""Extend CTC prob.
|
||||
|
||||
:param torch.Tensor x: input label posterior sequences (B, T, O)
|
||||
"""
|
||||
|
||||
if self.x.shape[1] < x.shape[1]: # self.x (2,T,B,O); x (B,T,O)
|
||||
# Pad the rest of posteriors in the batch
|
||||
# TODO(takaaki-hori): need a better way without for-loops
|
||||
xlens = [x.size(1)]
|
||||
for i, l in enumerate(xlens):
|
||||
if l < self.input_length:
|
||||
x[i, l:, :] = self.logzero
|
||||
x[i, l:, self.blank] = 0
|
||||
tmp_x = self.x
|
||||
xn = x.transpose(0, 1) # (B, T, O) -> (T, B, O)
|
||||
xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
|
||||
self.x = torch.stack([xn, xb]) # (2, T, B, O)
|
||||
self.x[:, : tmp_x.shape[1], :, :] = tmp_x
|
||||
self.input_length = x.size(1)
|
||||
self.end_frames = torch.as_tensor(xlens) - 1
|
||||
|
||||
def extend_state(self, state):
|
||||
"""Compute CTC prefix state.
|
||||
|
||||
|
||||
:param state : CTC state
|
||||
:return ctc_state
|
||||
"""
|
||||
|
||||
if state is None:
|
||||
# nothing to do
|
||||
return state
|
||||
else:
|
||||
r_prev, s_prev, f_min_prev, f_max_prev = state
|
||||
|
||||
r_prev_new = torch.full(
|
||||
(self.input_length, 2),
|
||||
self.logzero,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
start = max(r_prev.shape[0], 1)
|
||||
r_prev_new[0:start] = r_prev
|
||||
for t in six.moves.range(start, self.input_length):
|
||||
r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]
|
||||
|
||||
return (r_prev_new, s_prev, f_min_prev, f_max_prev)
|
||||
|
||||
|
||||
class CTCPrefixScore(object):
|
||||
"""Compute CTC label sequence scores
|
||||
|
||||
which is based on Algorithm 2 in WATANABE et al.
|
||||
"HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
|
||||
    but extended to efficiently compute the probabilities of multiple labels
|
||||
    simultaneously.
|
||||
"""
|
||||
|
||||
def __init__(self, x, blank, eos, xp):
|
||||
self.xp = xp
|
||||
self.logzero = -10000000000.0
|
||||
self.blank = blank
|
||||
self.eos = eos
|
||||
self.input_length = len(x)
|
||||
self.x = x
|
||||
|
||||
def initial_state(self):
|
||||
"""Obtain an initial CTC state
|
||||
|
||||
:return: CTC state
|
||||
"""
|
||||
# initial CTC state is made of a frame x 2 tensor that corresponds to
|
||||
# r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
|
||||
# superscripts n and b (non-blank and blank), respectively.
|
||||
r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
|
||||
r[0, 1] = self.x[0, self.blank]
|
||||
for i in six.moves.range(1, self.input_length):
|
||||
r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
|
||||
return r
|
||||
|
||||
def __call__(self, y, cs, r_prev):
|
||||
"""Compute CTC prefix scores for next labels
|
||||
|
||||
:param y : prefix label sequence
|
||||
:param cs : array of next labels
|
||||
:param r_prev: previous CTC state
|
||||
:return ctc_scores, ctc_states
|
||||
"""
|
||||
# initialize CTC states
|
||||
output_length = len(y) - 1 # ignore sos
|
||||
# new CTC states are prepared as a frame x (n or b) x n_labels tensor
|
||||
# that corresponds to r_t^n(h) and r_t^b(h).
|
||||
r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
|
||||
xs = self.x[:, cs]
|
||||
if output_length == 0:
|
||||
r[0, 0] = xs[0]
|
||||
r[0, 1] = self.logzero
|
||||
else:
|
||||
r[output_length - 1] = self.logzero
|
||||
|
||||
# prepare forward probabilities for the last label
|
||||
r_sum = self.xp.logaddexp(
|
||||
r_prev[:, 0], r_prev[:, 1]
|
||||
) # log(r_t^n(g) + r_t^b(g))
|
||||
last = y[-1]
|
||||
if output_length > 0 and last in cs:
|
||||
log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
|
||||
for i in six.moves.range(len(cs)):
|
||||
log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
|
||||
else:
|
||||
log_phi = r_sum
|
||||
|
||||
# compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
|
||||
# and log prefix probabilities log(psi)
|
||||
start = max(output_length, 1)
|
||||
log_psi = r[start - 1, 0]
|
||||
for t in six.moves.range(start, self.input_length):
|
||||
r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
|
||||
r[t, 1] = (
|
||||
self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
|
||||
)
|
||||
log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
|
||||
|
||||
# get P(...eos|X) that ends with the prefix itself
|
||||
eos_pos = self.xp.where(cs == self.eos)[0]
|
||||
if len(eos_pos) > 0:
|
||||
log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g))
|
||||
|
||||
# exclude blank probs
|
||||
blank_pos = self.xp.where(cs == self.blank)[0]
|
||||
if len(blank_pos) > 0:
|
||||
log_psi[blank_pos] = self.logzero
|
||||
|
||||
# return the log prefix probability and CTC states, where the label axis
|
||||
# of the CTC states is moved to the first axis to slice it easily
|
||||
return log_psi, self.xp.rollaxis(r, 2)
|
||||
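# Toy sketch of the numpy scorer above (labels 0=blank and 2=eos are
# purely illustrative; x is a (T, O) log-posterior matrix):
#
#   x = np.log(np.full((4, 3), 1.0 / 3.0, dtype=np.float32))
#   scorer = CTCPrefixScore(x, blank=0, eos=2, xp=np)
#   r0 = scorer.initial_state()
#   log_psi, states = scorer(y=[2], cs=np.array([1, 2]), r_prev=r0)
#   # log_psi[i] is the prefix score of appending cs[i] to the prefix y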
61
funasr_local/modules/scorers/length_bonus.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Length bonus module."""
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from funasr_local.modules.scorers.scorer_interface import BatchScorerInterface
|
||||
|
||||
|
||||
class LengthBonus(BatchScorerInterface):
|
||||
"""Length bonus in beam search."""
|
||||
|
||||
def __init__(self, n_vocab: int):
|
||||
"""Initialize class.
|
||||
|
||||
Args:
|
||||
n_vocab (int): The number of tokens in vocabulary for beam search
|
||||
|
||||
"""
|
||||
self.n = n_vocab
|
||||
|
||||
def score(self, y, state, x):
|
||||
"""Score new token.
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): 1D torch.int64 prefix tokens.
|
||||
state: Scorer state for prefix tokens
|
||||
x (torch.Tensor): 2D encoder feature that generates ys.
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, Any]: Tuple of
|
||||
torch.float32 scores for next token (n_vocab)
|
||||
and None
|
||||
|
||||
"""
|
||||
return torch.tensor([1.0], device=x.device, dtype=x.dtype).expand(self.n), None
|
||||
|
||||
def batch_score(
|
||||
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, List[Any]]:
|
||||
"""Score new token batch.
|
||||
|
||||
Args:
|
||||
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
|
||||
states (List[Any]): Scorer states for prefix tokens.
|
||||
xs (torch.Tensor):
|
||||
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, List[Any]]: Tuple of
|
||||
batchfied scores for next token with shape of `(n_batch, n_vocab)`
|
||||
and next state list for ys.
|
||||
|
||||
"""
|
||||
return (
|
||||
torch.tensor([1.0], device=xs.device, dtype=xs.dtype).expand(
|
||||
ys.shape[0], self.n
|
||||
),
|
||||
None,
|
||||
)
|
||||
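# Minimal usage sketch: LengthBonus returns a constant 1.0 for every
# vocabulary entry, so its beam-search weight directly sets the
# per-token length reward.
#
#   lb = LengthBonus(n_vocab=10)
#   scores, state = lb.score(torch.tensor([1, 2]), None, torch.zeros(5, 4))
#   assert scores.shape == (10,) and float(scores[0]) == 1.0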
188
funasr_local/modules/scorers/scorer_interface.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""Scorer interface module."""
|
||||
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
|
||||
class ScorerInterface:
|
||||
"""Scorer interface for beam search.
|
||||
|
||||
The scorer performs scoring of the all tokens in vocabulary.
|
||||
|
||||
Examples:
|
||||
* Search heuristics
|
||||
* :class:`espnet.nets.scorers.length_bonus.LengthBonus`
|
||||
* Decoder networks of the sequence-to-sequence models
|
||||
* :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
|
||||
* :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
|
||||
* Neural language models
|
||||
* :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
|
||||
* :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
|
||||
* :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`
|
||||
|
||||
"""
|
||||
|
||||
def init_state(self, x: torch.Tensor) -> Any:
|
||||
"""Get an initial state for decoding (optional).
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The encoded feature tensor
|
||||
|
||||
Returns: initial state
|
||||
|
||||
"""
|
||||
return None
|
||||
|
||||
def select_state(self, state: Any, i: int, new_id: int = None) -> Any:
|
||||
"""Select state with relative ids in the main beam search.
|
||||
|
||||
Args:
|
||||
state: Decoder state for prefix tokens
|
||||
i (int): Index to select a state in the main beam search
|
||||
new_id (int): New label index to select a state if necessary
|
||||
|
||||
Returns:
|
||||
state: pruned state
|
||||
|
||||
"""
|
||||
return None if state is None else state[i]
|
||||
|
||||
def score(
|
||||
self, y: torch.Tensor, state: Any, x: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
"""Score new token (required).
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): 1D torch.int64 prefix tokens.
|
||||
state: Scorer state for prefix tokens
|
||||
x (torch.Tensor): The encoder feature that generates ys.
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, Any]: Tuple of
|
||||
scores for next token that has a shape of `(n_vocab)`
|
||||
and next state for ys
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def final_score(self, state: Any) -> float:
|
||||
"""Score eos (optional).
|
||||
|
||||
Args:
|
||||
state: Scorer state for prefix tokens
|
||||
|
||||
Returns:
|
||||
float: final score
|
||||
|
||||
"""
|
||||
return 0.0
|
||||
|
||||
|
||||
class BatchScorerInterface(ScorerInterface):
|
||||
"""Batch scorer interface."""
|
||||
|
||||
def batch_init_state(self, x: torch.Tensor) -> Any:
|
||||
"""Get an initial state for decoding (optional).
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The encoded feature tensor
|
||||
|
||||
Returns: initial state
|
||||
|
||||
"""
|
||||
return self.init_state(x)
|
||||
|
||||
def batch_score(
|
||||
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, List[Any]]:
|
||||
"""Score new token batch (required).
|
||||
|
||||
Args:
|
||||
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
|
||||
states (List[Any]): Scorer states for prefix tokens.
|
||||
xs (torch.Tensor):
|
||||
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, List[Any]]: Tuple of
|
||||
batchfied scores for next token with shape of `(n_batch, n_vocab)`
|
||||
and next state list for ys.
|
||||
|
||||
"""
|
||||
warnings.warn(
|
||||
"{} batch score is implemented through for loop not parallelized".format(
|
||||
self.__class__.__name__
|
||||
)
|
||||
)
|
||||
scores = list()
|
||||
outstates = list()
|
||||
for i, (y, state, x) in enumerate(zip(ys, states, xs)):
|
||||
score, outstate = self.score(y, state, x)
|
||||
outstates.append(outstate)
|
||||
scores.append(score)
|
||||
scores = torch.cat(scores, 0).view(ys.shape[0], -1)
|
||||
return scores, outstates
|
||||
|
||||
|
||||
class PartialScorerInterface(ScorerInterface):
|
||||
"""Partial scorer interface for beam search.
|
||||
|
||||
The partial scorer performs scoring when non-partial scorer finished scoring,
|
||||
and receives pre-pruned next tokens to score because it is too heavy to score
|
||||
all the tokens.
|
||||
|
||||
Examples:
|
||||
* Prefix search for connectionist-temporal-classification models
|
||||
* :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`
|
||||
|
||||
"""
|
||||
|
||||
def score_partial(
|
||||
self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
"""Score new token (required).
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): 1D prefix token
|
||||
next_tokens (torch.Tensor): torch.int64 next token to score
|
||||
state: decoder state for prefix tokens
|
||||
x (torch.Tensor): The encoder feature that generates ys
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, Any]:
|
||||
Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
|
||||
and next state for ys
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface):
|
||||
"""Batch partial scorer interface for beam search."""
|
||||
|
||||
def batch_score_partial(
|
||||
self,
|
||||
ys: torch.Tensor,
|
||||
next_tokens: torch.Tensor,
|
||||
states: List[Any],
|
||||
xs: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
"""Score new token (required).
|
||||
|
||||
Args:
|
||||
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
|
||||
next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch, n_token).
|
||||
states (List[Any]): Scorer states for prefix tokens.
|
||||
xs (torch.Tensor):
|
||||
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, Any]:
|
||||
Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
|
||||
and next states for ys
|
||||
"""
|
||||
raise NotImplementedError
|
||||
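# A minimal custom scorer sketch: only score() is required, the rest of
# the interface has defaults. This toy scorer gives every token the same
# (unnormalized) score and keeps no state:
#
#   class UniformScorer(ScorerInterface):
#       def __init__(self, n_vocab):
#           self.n_vocab = n_vocab
#
#       def score(self, y, state, x):
#           return torch.zeros(self.n_vocab), None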
0
funasr_local/modules/streaming_utils/__init__.py
Normal file
390
funasr_local/modules/streaming_utils/chunk_utilis.py
Normal file
@@ -0,0 +1,390 @@
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
from funasr_local.modules.nets_utils import make_pad_mask
|
||||
import logging
|
||||
import torch.nn.functional as F
|
||||
from funasr_local.modules.streaming_utils.utils import sequence_mask
|
||||
|
||||
|
||||
|
||||
class overlap_chunk():
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
    SAN-M: Memory Equipped Self-Attention for End-to-End Speech Recognition
|
||||
https://arxiv.org/abs/2006.01713
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
chunk_size: tuple = (16,),
|
||||
stride: tuple = (10,),
|
||||
pad_left: tuple = (0,),
|
||||
encoder_att_look_back_factor: tuple = (1,),
|
||||
shfit_fsmn: int = 0,
|
||||
decoder_att_look_back_factor: tuple = (1,),
|
||||
):
|
||||
|
||||
pad_left = self.check_chunk_size_args(chunk_size, pad_left)
|
||||
encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor)
|
||||
decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor)
|
||||
self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor, self.decoder_att_look_back_factor \
|
||||
= chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor
|
||||
self.shfit_fsmn = shfit_fsmn
|
||||
self.x_add_mask = None
|
||||
self.x_rm_mask = None
|
||||
self.x_len = None
|
||||
self.mask_shfit_chunk = None
|
||||
self.mask_chunk_predictor = None
|
||||
self.mask_att_chunk_encoder = None
|
||||
self.mask_shift_att_chunk_decoder = None
|
||||
self.chunk_outs = None
|
||||
self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
|
||||
= None, None, None, None, None
|
||||
|
||||
def check_chunk_size_args(self, chunk_size, x):
|
||||
if len(x) < len(chunk_size):
|
||||
            x = [x[0] for _ in chunk_size]
|
||||
return x
|
||||
|
||||
def get_chunk_size(self,
|
||||
ind: int = 0
|
||||
):
|
||||
# with torch.no_grad:
|
||||
chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor = \
|
||||
self.chunk_size[ind], self.stride[ind], self.pad_left[ind], self.encoder_att_look_back_factor[ind], self.decoder_att_look_back_factor[ind]
|
||||
self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur, self.decoder_att_look_back_factor_cur \
|
||||
= chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn, decoder_att_look_back_factor
|
||||
return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur
|
||||
|
||||
def random_choice(self, training=True, decoding_ind=None):
|
||||
chunk_num = len(self.chunk_size)
|
||||
ind = 0
|
||||
if training and chunk_num > 1:
|
||||
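            # NOTE: torch.randint's upper bound is exclusive, so this draws
            # from [0, chunk_num - 2] and never samples the last chunk size.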
ind = torch.randint(0, chunk_num-1, ()).cpu().item()
|
||||
if not training and decoding_ind is not None:
|
||||
ind = int(decoding_ind)
|
||||
|
||||
return ind
|
||||
|
||||
|
||||
|
||||
|
||||
def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
|
||||
|
||||
with torch.no_grad():
|
||||
x_len = x_len.cpu().numpy()
|
||||
x_len_max = x_len.max()
|
||||
|
||||
chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind)
|
||||
shfit_fsmn = self.shfit_fsmn
|
||||
pad_right = chunk_size - stride - pad_left
|
||||
|
||||
chunk_num_batch = np.ceil(x_len/stride).astype(np.int32)
|
||||
x_len_chunk = (chunk_num_batch-1) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (chunk_num_batch-1) * stride
|
||||
x_len_chunk = x_len_chunk.astype(x_len.dtype)
|
||||
x_len_chunk_max = x_len_chunk.max()
|
||||
|
||||
chunk_num = int(math.ceil(x_len_max/stride))
|
||||
dtype = np.int32
|
||||
max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
|
||||
x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
|
||||
x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
|
||||
mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
|
||||
mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
|
||||
mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
|
||||
mask_att_chunk_encoder = np.zeros([0, chunk_num*chunk_size_pad_shift], dtype=dtype)
|
||||
for chunk_ids in range(chunk_num):
|
||||
# x_mask add
|
||||
fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
|
||||
x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
|
||||
x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
|
||||
x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
|
||||
x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
|
||||
x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
|
||||
x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0)
|
||||
x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)
|
||||
|
||||
# x_mask rm
|
||||
fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn),dtype=dtype)
|
||||
padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left),dtype=dtype)
|
||||
padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
|
||||
x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
|
||||
x_mask_cur_pad_top = np.zeros((chunk_ids*stride, stride), dtype=dtype)
|
||||
x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype)
|
||||
x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0)
|
||||
x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride]
|
||||
x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1)
|
||||
x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)
|
||||
|
||||
# fsmn_padding_mask
|
||||
pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
|
||||
ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
|
||||
mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
|
||||
mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
|
||||
|
||||
# predictor mask
|
||||
zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
|
||||
ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
|
||||
zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype)
|
||||
ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
|
||||
mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0)
|
||||
mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)
|
||||
|
||||
# encoder att mask
|
||||
zeros_1_top = np.zeros([shfit_fsmn, chunk_num*chunk_size_pad_shift], dtype=dtype)
|
||||
|
||||
zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
|
||||
zeros_2 = np.zeros([chunk_size, zeros_2_num*chunk_size_pad_shift], dtype=dtype)
|
||||
|
||||
encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
|
||||
zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
|
||||
ones_2_mid = np.ones([stride, stride], dtype=dtype)
|
||||
zeros_2_bottom = np.zeros([chunk_size-stride, stride], dtype=dtype)
|
||||
zeros_2_right = np.zeros([chunk_size, chunk_size-stride], dtype=dtype)
|
||||
ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
|
||||
ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
|
||||
ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
|
||||
|
||||
zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
|
||||
ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
|
||||
ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
|
||||
|
||||
zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
|
||||
zeros_remain = np.zeros([chunk_size, zeros_remain_num*chunk_size_pad_shift], dtype=dtype)
|
||||
|
||||
ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1)
|
||||
mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0)
|
||||
mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0)
|
||||
|
||||
|
||||
# decoder fsmn_shift_att_mask
|
||||
zeros_1 = np.zeros([shfit_fsmn, 1])
|
||||
ones_1 = np.ones([chunk_size, 1])
|
||||
mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
|
||||
mask_shift_att_chunk_decoder = np.concatenate(
|
||||
[mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0)
|
||||
|
||||
self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max+pad_left]
|
||||
self.x_len_chunk = x_len_chunk
|
||||
self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
|
||||
self.x_len = x_len
|
||||
self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
|
||||
self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
|
||||
self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
|
||||
self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
|
||||
self.chunk_outs = (self.x_add_mask,
|
||||
self.x_len_chunk,
|
||||
self.x_rm_mask,
|
||||
self.x_len,
|
||||
self.mask_shfit_chunk,
|
||||
self.mask_chunk_predictor,
|
||||
self.mask_att_chunk_encoder,
|
||||
self.mask_shift_att_chunk_decoder)
|
||||
|
||||
return self.chunk_outs
|
||||
|
||||
|
||||
def split_chunk(self, x, x_len, chunk_outs):
|
||||
"""
|
||||
:param x: (b, t, d)
|
||||
        :param x_len: (b)
|
||||
        :param chunk_outs: chunk mask tuple produced by gen_chunk_mask
|
||||
:return:
|
||||
"""
|
||||
x = x[:, :x_len.max(), :]
|
||||
b, t, d = x.size()
|
||||
x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(
|
||||
x.device)
|
||||
x *= x_len_mask[:, :, None]
|
||||
|
||||
x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
|
||||
x_len_chunk = self.get_x_len_chunk(chunk_outs, x_len.device, dtype=x_len.dtype)
|
||||
pad = (0, 0, self.pad_left_cur, 0)
|
||||
x = F.pad(x, pad, "constant", 0.0)
|
||||
b, t, d = x.size()
|
||||
x = torch.transpose(x, 1, 0)
|
||||
x = torch.reshape(x, [t, -1])
|
||||
x_chunk = torch.mm(x_add_mask, x)
|
||||
x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)
|
||||
|
||||
return x_chunk, x_len_chunk
|
||||
|
||||
def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
|
||||
x_chunk = x_chunk[:, :x_len_chunk.max(), :]
|
||||
b, t, d = x_chunk.size()
|
||||
x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(
|
||||
x_chunk.device)
|
||||
x_chunk *= x_len_chunk_mask[:, :, None]
|
||||
|
||||
x_rm_mask = self.get_x_rm_mask(chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
|
||||
x_len = self.get_x_len(chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)
|
||||
x_chunk = torch.transpose(x_chunk, 1, 0)
|
||||
x_chunk = torch.reshape(x_chunk, [t, -1])
|
||||
x = torch.mm(x_rm_mask, x_chunk)
|
||||
x = torch.reshape(x, [-1, b, d]).transpose(1, 0)
|
||||
|
||||
return x, x_len
|
||||
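    # Round-trip sketch (illustrative, assuming shfit_fsmn == 0 and the
    # default single chunk configuration):
    #
    #   oc = overlap_chunk(chunk_size=(16,), stride=(10,), pad_left=(0,))
    #   x, x_len = torch.randn(1, 30, 4), torch.tensor([30])
    #   outs = oc.gen_chunk_mask(x_len, ind=oc.random_choice(training=False))
    #   x_chunk, len_chunk = oc.split_chunk(x, x_len, outs)
    #   x_back, len_back = oc.remove_chunk(x_chunk, len_chunk, outs)
    #   # x_back recovers the original frames up to the zero padding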
|
||||
def get_x_add_mask(self, chunk_outs=None, device='cpu', idx=0, dtype=torch.float32):
|
||||
with torch.no_grad():
|
||||
x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x
|
||||
|
||||
def get_x_len_chunk(self, chunk_outs=None, device='cpu', idx=1, dtype=torch.float32):
|
||||
with torch.no_grad():
|
||||
x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x
|
||||
|
||||
|
||||
def get_x_rm_mask(self, chunk_outs=None, device='cpu', idx=2, dtype=torch.float32):
|
||||
with torch.no_grad():
|
||||
x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x
|
||||
|
||||
def get_x_len(self, chunk_outs=None, device='cpu', idx=3, dtype=torch.float32):
|
||||
with torch.no_grad():
|
||||
x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x
|
||||
|
||||
|
||||
def get_mask_shfit_chunk(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=4, dtype=torch.float32):
|
||||
with torch.no_grad():
|
||||
x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
|
||||
x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x
|
||||
|
||||
def get_mask_chunk_predictor(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=5, dtype=torch.float32):
|
||||
with torch.no_grad():
|
||||
x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
|
||||
x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x
|
||||
|
||||
def get_mask_att_chunk_encoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=6, dtype=torch.float32):
|
||||
with torch.no_grad():
|
||||
x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
|
||||
x = np.tile(x[None, :, :, ], [batch_size, 1, 1])
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x
|
||||
|
||||
def get_mask_shift_att_chunk_decoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=7, dtype=torch.float32):
|
||||
with torch.no_grad():
|
||||
x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
|
||||
x = np.tile(x[None, None, :, 0], [batch_size, 1, 1])
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
def build_scama_mask_for_cross_attention_decoder(
|
||||
predictor_alignments: torch.Tensor,
|
||||
encoder_sequence_length: torch.Tensor,
|
||||
chunk_size: int = 5,
|
||||
encoder_chunk_size: int = 5,
|
||||
attention_chunk_center_bias: int = 0,
|
||||
attention_chunk_size: int = 1,
|
||||
attention_chunk_type: str = 'chunk',
|
||||
step=None,
|
||||
predictor_mask_chunk_hopping: torch.Tensor = None,
|
||||
decoder_att_look_back_factor: int = 1,
|
||||
mask_shift_att_chunk_decoder: torch.Tensor = None,
|
||||
target_length: torch.Tensor = None,
|
||||
is_training=True,
|
||||
dtype: torch.dtype = torch.float32):
|
||||
with torch.no_grad():
|
||||
device = predictor_alignments.device
|
||||
batch_size, chunk_num = predictor_alignments.size()
|
||||
maximum_encoder_length = encoder_sequence_length.max().item()
|
||||
int_type = predictor_alignments.dtype
|
||||
if not is_training:
|
||||
target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype)
|
||||
maximum_target_length = target_length.max()
|
||||
predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1)
|
||||
predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1)
|
||||
|
||||
|
||||
index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device)
|
||||
index = torch.cumsum(index, dim=1)
|
||||
index = index[:, :, None].repeat(1, 1, chunk_num)
|
||||
|
||||
index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type)
|
||||
index_div_bool_zeros = index_div == 0
|
||||
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1
|
||||
|
||||
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num)
|
||||
|
||||
index_div_bool_zeros_count *= chunk_size
|
||||
index_div_bool_zeros_count += attention_chunk_center_bias
|
||||
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count-1, min=0, max=maximum_encoder_length)
|
||||
index_div_bool_zeros_count_ori = index_div_bool_zeros_count
|
||||
|
||||
index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size)+1)*encoder_chunk_size
|
||||
max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size
|
||||
|
||||
mask_flip, mask_flip2 = None, None
|
||||
if attention_chunk_size is not None:
|
||||
index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size
|
||||
index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
|
||||
index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
|
||||
mask_flip = 1 - index_div_bool_zeros_count_beg_mask
|
||||
attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor+1)
|
||||
index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2
|
||||
|
||||
index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
|
||||
index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
|
||||
mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask
|
||||
|
||||
mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device)
|
||||
|
||||
if predictor_mask_chunk_hopping is not None:
|
||||
b, k, t = mask.size()
|
||||
predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1)
|
||||
|
||||
mask_mask_flip = mask
|
||||
if mask_flip is not None:
|
||||
mask_mask_flip = mask_flip * mask
|
||||
|
||||
def _fn():
|
||||
mask_sliced = mask[:b, :k, encoder_chunk_size:t]
|
||||
zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device)
|
||||
mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2)
|
||||
_, _, tt = predictor_mask_chunk_hopping.size()
|
||||
pad_right_p = max_len_chunk - tt
|
||||
predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0)
|
||||
masked = mask_sliced * predictor_mask_chunk_hopping_pad
|
||||
|
||||
mask_true = mask_mask_flip + masked
|
||||
return mask_true
|
||||
|
||||
mask = _fn() if t > chunk_size else mask_mask_flip
|
||||
|
||||
|
||||
|
||||
if mask_flip2 is not None:
|
||||
mask *= mask_flip2
|
||||
|
||||
mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device)
|
||||
mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None]
|
||||
|
||||
|
||||
|
||||
mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device)
|
||||
mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :]
|
||||
|
||||
|
||||
|
||||
|
||||
if attention_chunk_type == 'full':
|
||||
mask = torch.ones_like(mask).to(device)
|
||||
if mask_shift_att_chunk_decoder is not None:
|
||||
mask = mask * mask_shift_att_chunk_decoder
|
||||
mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device)
|
||||
|
||||
return mask
|
||||
|
||||
62
funasr_local/modules/streaming_utils/load_fr_tf.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import numpy as np
|
||||
np.set_printoptions(threshold=np.inf)
|
||||
import logging
|
||||
|
||||
def load_ckpt(checkpoint_path):
|
||||
import tensorflow as tf
|
||||
if tf.__version__.startswith('2'):
|
||||
import tensorflow.compat.v1 as tf
|
||||
tf.disable_v2_behavior()
|
||||
reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
|
||||
else:
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
|
||||
var_to_shape_map = reader.get_variable_to_shape_map()
|
||||
|
||||
var_dict = dict()
|
||||
for var_name in sorted(var_to_shape_map):
|
||||
if "Adam" in var_name:
|
||||
continue
|
||||
tensor = reader.get_tensor(var_name)
|
||||
# print("in ckpt: {}, {}".format(var_name, tensor.shape))
|
||||
# print(tensor)
|
||||
var_dict[var_name] = tensor
|
||||
|
||||
return var_dict
|
||||
|
||||
|
||||
|
||||
def load_tf_pb_dict(pb_model):
|
||||
import tensorflow as tf
|
||||
if tf.__version__.startswith('2'):
|
||||
import tensorflow.compat.v1 as tf
|
||||
tf.disable_v2_behavior()
|
||||
# import tensorflow_addons as tfa
|
||||
# from tensorflow_addons.seq2seq.python.ops import beam_search_ops
|
||||
else:
|
||||
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
|
||||
from tensorflow.python.ops import lookup_ops as lookup
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.platform import gfile
|
||||
|
||||
sess = tf.Session()
|
||||
with gfile.FastGFile(pb_model, 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
sess.graph.as_default()
|
||||
tf.import_graph_def(graph_def, name='')
|
||||
|
||||
var_dict = dict()
|
||||
for node in sess.graph_def.node:
|
||||
if node.op == 'Const':
|
||||
value = tensor_util.MakeNdarray(node.attr['value'].tensor)
|
||||
if len(value.shape) >= 1:
|
||||
var_dict[node.name] = value
|
||||
return var_dict
|
||||
|
||||
def load_tf_dict(pb_model):
|
||||
if "model.ckpt-" in pb_model:
|
||||
var_dict = load_ckpt(pb_model)
|
||||
else:
|
||||
var_dict = load_tf_pb_dict(pb_model)
|
||||
return var_dict
|
||||
91
funasr_local/modules/streaming_utils/utils.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import os
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import yaml
|
||||
import numpy as np
|
||||
|
||||
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
|
||||
if maxlen is None:
|
||||
maxlen = lengths.max()
|
||||
row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
|
||||
matrix = torch.unsqueeze(lengths, dim=-1)
|
||||
mask = row_vector < matrix
|
||||
mask = mask.detach()
|
||||
|
||||
return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
|
||||
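# Example (illustrative): sequence_mask(torch.tensor([2, 4]), maxlen=5)
# returns
#   [[1., 1., 0., 0., 0.],
#    [1., 1., 1., 1., 0.]]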
|
||||
def apply_cmvn(inputs, mvn):
|
||||
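    # NOTE: given the subtract-then-multiply below, mvn[0] is expected to
    # hold the feature means and mvn[1] the *inverse* standard deviations.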
device = inputs.device
|
||||
dtype = inputs.dtype
|
||||
frame, dim = inputs.shape
|
||||
    means = np.tile(mvn[0:1, :dim], (frame, 1))
|
||||
    vars = np.tile(mvn[1:2, :dim], (frame, 1))
|
||||
    inputs -= torch.from_numpy(means).type(dtype).to(device)
|
||||
inputs *= torch.from_numpy(vars).type(dtype).to(device)
|
||||
|
||||
return inputs.type(torch.float32)
|
||||
|
||||
|
||||
|
||||
|
||||
def drop_and_add(inputs: torch.Tensor,
|
||||
outputs: torch.Tensor,
|
||||
training: bool,
|
||||
dropout_rate: float = 0.1,
|
||||
stoch_layer_coeff: float = 1.0):
|
||||
|
||||
|
||||
|
||||
outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
|
||||
outputs *= stoch_layer_coeff
|
||||
|
||||
input_dim = inputs.size(-1)
|
||||
output_dim = outputs.size(-1)
|
||||
|
||||
if input_dim == output_dim:
|
||||
outputs += inputs
|
||||
return outputs
|
||||
|
||||
|
||||
def proc_tf_vocab(vocab_path):
|
||||
with open(vocab_path, encoding="utf-8") as f:
|
||||
token_list = [line.rstrip() for line in f]
|
||||
if '<unk>' not in token_list:
|
||||
token_list.append('<unk>')
|
||||
return token_list
|
||||
|
||||
|
||||
def gen_config_for_tfmodel(config_path, vocab_path, output_dir):
|
||||
token_list = proc_tf_vocab(vocab_path)
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
config['token_list'] = token_list
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
|
||||
|
||||
|
||||
class NoAliasSafeDumper(yaml.SafeDumper):
|
||||
    # Disable anchor/alias in yaml because it looks ugly
|
||||
def ignore_aliases(self, data):
|
||||
return True
|
||||
|
||||
|
||||
def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
|
||||
"""Safe-dump in yaml with no anchor/alias"""
|
||||
return yaml.dump(
|
||||
data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
|
||||
config_path = sys.argv[1]
|
||||
vocab_path = sys.argv[2]
|
||||
output_dir = sys.argv[3]
|
||||
gen_config_for_tfmodel(config_path, vocab_path, output_dir)
|
||||
611
funasr_local/modules/subsampling.py
Normal file
@@ -0,0 +1,611 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Subsampling layer definition."""
import numpy as np
import torch
import torch.nn.functional as F
from funasr_local.modules.embedding import PositionalEncoding
import logging
from funasr_local.modules.streaming_utils.utils import sequence_mask
from funasr_local.modules.nets_utils import sub_factor_to_params, pad_to_len
from typing import Optional, Tuple, Union
import math


class TooShortUttError(Exception):
    """Raised when the utterance is too short for subsampling.

    Args:
        message (str): Message for error catch
        actual_size (int): the short size that cannot pass the subsampling
        limit (int): the limit size for subsampling

    """

    def __init__(self, message, actual_size, limit):
        """Construct a TooShortUttError for error handler."""
        super().__init__(message)
        self.actual_size = actual_size
        self.limit = limit


def check_short_utt(ins, size):
    """Check if the utterance is too short for subsampling."""
    if isinstance(ins, Conv2dSubsampling2) and size < 3:
        return True, 3
    if isinstance(ins, Conv2dSubsampling) and size < 7:
        return True, 7
    if isinstance(ins, Conv2dSubsampling6) and size < 11:
        return True, 11
    if isinstance(ins, Conv2dSubsampling8) and size < 15:
        return True, 15
    return False, -1

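# Usage sketch (an assumed caller pattern, not part of this file): guard an
# encoder against inputs shorter than its subsampling frontend allows.
#
#     too_short, limit = check_short_utt(subsampler, feats.size(1))
#     if too_short:
#         raise TooShortUttError(
#             f"input has {feats.size(1)} frames, but at least {limit} are required",
#             feats.size(1),
#             limit,
#         )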
class Conv2dSubsampling(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling object."""
        super(Conv2dSubsampling, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 4.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 4.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        return x, x_mask[:, :, :-2:2][:, :, :-2:2]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]

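# Shape sketch (sizes assumed for illustration): each conv stage maps the time
# axis t to (t - 3) // 2 + 1, so 100 -> 49 -> 24 frames, and the mask is
# thinned with the matching `[:, :, :-2:2]` slices.
#
#     sub = Conv2dSubsampling(idim=80, odim=256, dropout_rate=0.1)
#     x = torch.randn(2, 100, 80)
#     mask = torch.ones(2, 1, 100, dtype=torch.bool)
#     y, y_mask = sub(x, mask)   # y: (2, 24, 256), y_mask: (2, 1, 24)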
class Conv2dSubsamplingPad(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/4 length), with right-padding of the input.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsamplingPad object."""
        super(Conv2dSubsamplingPad, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2, padding=(0, 0)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2, padding=(0, 0)),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )
        self.pad_fn = torch.nn.ConstantPad1d((0, 4), 0.0)

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 4.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 4.

        """
        x = x.transpose(1, 2)  # (b, f, t)
        x = self.pad_fn(x)     # pad 4 zero frames on the right of the time axis
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        x_len = torch.sum(x_mask[:, 0, :], dim=-1)
        x_len = (x_len - 1) // 2 + 1
        x_len = (x_len - 1) // 2 + 1
        mask = sequence_mask(x_len, None, x_len.dtype, x[0].device)
        return x, mask[:, None, :]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]


class Conv2dSubsampling2(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/2 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling2 object."""
        super(Conv2dSubsampling2, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 1),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * ((idim - 1) // 2 - 2), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 2.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 2.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        return x, x_mask[:, :, :-2:2][:, :, :-2:1]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]


class Conv2dSubsampling6(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/6 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling6 object."""
        super(Conv2dSubsampling6, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 6.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 6.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        return x, x_mask[:, :, :-2:2][:, :, :-4:3]


class Conv2dSubsampling8(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/8 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling8 object."""
        super(Conv2dSubsampling8, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 8.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 8.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]

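# Shape sketch (sizes assumed for illustration): three stride-2 convs give
# roughly 1/8-length output, since each stage maps t to (t - 3) // 2 + 1,
# i.e. 200 -> 99 -> 49 -> 24 frames.
#
#     sub8 = Conv2dSubsampling8(idim=80, odim=256, dropout_rate=0.1)
#     y, y_mask = sub8(torch.randn(1, 200, 80),
#                      torch.ones(1, 1, 200, dtype=torch.bool))
#     # y: (1, 24, 256), y_mask: (1, 1, 24)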
class Conv1dSubsampling(torch.nn.Module):
    """Convolutional 1D subsampling (to 1/stride length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        kernel_size (int): Kernel size of the Conv1d.
        stride (int): Stride of the Conv1d.
        pad (tuple): Left/right padding sizes for ConstantPad1d.

    """

    def __init__(self, idim, odim, kernel_size, stride, pad,
                 tf2torch_tensor_name_prefix_torch: str = "stride_conv",
                 tf2torch_tensor_name_prefix_tf: str = "seq2seq/proj_encoder/downsampling",
                 ):
        super(Conv1dSubsampling, self).__init__()
        self.conv = torch.nn.Conv1d(idim, odim, kernel_size, stride)
        self.pad_fn = torch.nn.ConstantPad1d(pad, 0.0)
        self.stride = stride
        self.odim = odim
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

    def output_size(self) -> int:
        return self.odim

    def forward(self, x, x_len):
        """Subsample x."""
        x = x.transpose(1, 2)  # (b, d, t)
        x = self.pad_fn(x)
        x = F.relu(self.conv(x))
        x = x.transpose(1, 2)  # (b, t, d)

        if x_len is None:
            return x, None
        x_len = (x_len - 1) // self.stride + 1
        return x, x_len

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## predictor
            "{}.conv.weight".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },  # (256,256,3),(3,256,256)
            "{}.conv.bias".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
        }
        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()

        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                name_tf = map_dict[name]["name"]
                data_tf = var_dict_tf[name_tf]
                if map_dict[name]["squeeze"] is not None:
                    data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                if map_dict[name]["transpose"] is not None:
                    data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")

                var_dict_torch_update[name] = data_tf

                logging.info(
                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))
        return var_dict_torch_update

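# Usage sketch (hypothetical sizes): halve the frame rate with a strided Conv1d.
#
#     sub = Conv1dSubsampling(idim=256, odim=256, kernel_size=3, stride=2, pad=(1, 1))
#     x = torch.randn(4, 100, 256)
#     x_len = torch.tensor([100, 80, 60, 40])
#     y, y_len = sub(x, x_len)   # y: (4, 50, 256), y_len == (x_len - 1) // 2 + 1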
class StreamingConvInput(torch.nn.Module):
    """Streaming ConvInput module definition.

    Args:
        input_size: Input size.
        conv_size: Convolution size.
        subsampling_factor: Subsampling factor.
        vgg_like: Whether to use a VGG-like network.
        output_size: Block output dimension.

    """

    def __init__(
        self,
        input_size: int,
        conv_size: Union[int, Tuple],
        subsampling_factor: int = 4,
        vgg_like: bool = True,
        output_size: Optional[int] = None,
    ) -> None:
        """Construct a StreamingConvInput object."""
        super().__init__()
        if vgg_like:
            if subsampling_factor == 1:
                conv_size1, conv_size2 = conv_size

                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((1, 2)),
                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((1, 2)),
                )

                output_proj = conv_size2 * ((input_size // 2) // 2)

                self.subsampling_factor = 1
                self.stride_1 = 1
                self.create_new_mask = self.create_new_vgg_mask
            else:
                conv_size1, conv_size2 = conv_size

                kernel_1 = int(subsampling_factor / 2)

                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((kernel_1, 2)),
                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((2, 2)),
                )

                output_proj = conv_size2 * ((input_size // 2) // 2)

                self.subsampling_factor = subsampling_factor
                self.create_new_mask = self.create_new_vgg_mask
                self.stride_1 = kernel_1
        else:
            if subsampling_factor == 1:
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size, 3, [1, 2], [1, 0]),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size, conv_size, 3, [1, 2], [1, 0]),
                    torch.nn.ReLU(),
                )

                output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)

                self.subsampling_factor = subsampling_factor
                self.kernel_2 = 3
                self.stride_2 = 1
                self.create_new_mask = self.create_new_conv2d_mask
            else:
                kernel_2, stride_2, conv_2_output_size = sub_factor_to_params(
                    subsampling_factor,
                    input_size,
                )

                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size, 3, 2),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
                    torch.nn.ReLU(),
                )

                output_proj = conv_size * conv_2_output_size

                self.subsampling_factor = subsampling_factor
                self.kernel_2 = kernel_2
                self.stride_2 = stride_2
                self.create_new_mask = self.create_new_conv2d_mask

        self.vgg_like = vgg_like
        self.min_frame_length = 7

        if output_size is not None:
            self.output = torch.nn.Linear(output_proj, output_size)
            self.output_size = output_size
        else:
            self.output = None
            self.output_size = output_proj

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[int]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: ConvInput input sequences. (B, T, D_feats)
            mask: Mask of input sequences. (B, 1, T)

        Returns:
            x: ConvInput output sequences. (B, sub(T), D_out)
            mask: Mask of output sequences. (B, 1, sub(T))

        """
        # NOTE: the final return slices the mask, so mask must not be None here.
        if mask is not None:
            mask = self.create_new_mask(mask)
            olens = max(mask.eq(0).sum(1))

        b, t, f = x.size()
        x = x.unsqueeze(1)  # (b, 1, t, f)

        if chunk_size is not None:
            max_input_length = int(
                chunk_size * self.subsampling_factor
                * (math.ceil(float(t) / (chunk_size * self.subsampling_factor)))
            )
            x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
            x = list(x)
            x = torch.stack(x, dim=0)
            N_chunks = max_input_length // (chunk_size * self.subsampling_factor)
            x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)

        x = self.conv(x)

        _, c, _, f = x.size()
        if chunk_size is not None:
            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:, :olens, :]
        else:
            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)

        if self.output is not None:
            x = self.output(x)

        return x, mask[:, :olens][:, :x.size(1)]

    def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Create a new mask for VGG output sequences.

        Args:
            mask: Mask of input sequences. (B, T)

        Returns:
            mask: Mask of output sequences. (B, sub(T))

        """
        if self.subsampling_factor > 1:
            vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2))
            mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2]

            vgg2_t_len = mask.size(1) - (mask.size(1) % 2)
            mask = mask[:, :vgg2_t_len][:, ::2]
        return mask

    def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Create new conformer mask for Conv2d output sequences.

        Args:
            mask: Mask of input sequences. (B, T)

        Returns:
            mask: Mask of output sequences. (B, sub(T))

        """
        if self.subsampling_factor > 1:
            return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
        else:
            return mask

    def get_size_before_subsampling(self, size: int) -> int:
        """Return the original size before subsampling for a given size.

        Args:
            size: Number of frames after subsampling.

        Returns:
            : Number of frames before subsampling.

        """
        return size * self.subsampling_factor

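A sketch of how `create_new_conv2d_mask` thins a (B, T) mask; the kernel_2=3,
stride_2=2 values are an assumption for subsampling_factor=4 (the real values
come from `sub_factor_to_params`, whose outputs are not shown in this file):

    mask = torch.ones(1, 20, dtype=torch.bool)
    m1 = mask[:, :-2:2]        # after the first Conv2d(kernel 3, stride 2): 20 -> 9
    m2 = m1[:, :-(3 - 1):2]    # after the second conv: 9 -> 4 frames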
61
funasr_local/modules/subsampling_without_posenc.py
Normal file
@@ -0,0 +1,61 @@
# Copyright 2020 Emiru Tsunoo
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Subsampling layer definition."""

import math
import torch


class Conv2dSubsamplingWOPosEnc(torch.nn.Module):
    """Convolutional 2D subsampling without positional encoding.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        kernels (list): kernel sizes
        strides (list): stride sizes

    """

    def __init__(self, idim, odim, dropout_rate, kernels, strides):
        """Construct a Conv2dSubsamplingWOPosEnc object."""
        assert len(kernels) == len(strides)
        super().__init__()
        conv = []
        olen = idim
        for i, (k, s) in enumerate(zip(kernels, strides)):
            conv += [
                torch.nn.Conv2d(1 if i == 0 else odim, odim, k, s),
                torch.nn.ReLU(),
            ]
            olen = math.floor((olen - k) / s + 1)
        self.conv = torch.nn.Sequential(*conv)
        self.out = torch.nn.Linear(odim * olen, odim)
        self.strides = strides
        self.kernels = kernels

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' is determined by the kernels and strides.
            torch.Tensor: Subsampled mask (#batch, 1, time').

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        for k, s in zip(self.kernels, self.strides):
            x_mask = x_mask[:, :, : -k + 1 : s]
        return x, x_mask
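A shape-check sketch (sizes assumed for illustration): two (kernel=3, stride=2)
stages give roughly 1/4-length output, and the mask is thinned with the same
per-stage arithmetic:

    sub = Conv2dSubsamplingWOPosEnc(
        idim=80, odim=256, dropout_rate=0.0, kernels=[3, 3], strides=[2, 2]
    )
    x = torch.randn(2, 100, 80)
    mask = torch.ones(2, 1, 100, dtype=torch.bool)
    y, y_mask = sub(x, mask)   # y: (2, 24, 256), y_mask: (2, 1, 24)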