mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
0
funasr_local/modules/streaming_utils/__init__.py
Normal file
390
funasr_local/modules/streaming_utils/chunk_utilis.py
Normal file
@@ -0,0 +1,390 @@
import torch
import numpy as np
import math
from funasr_local.modules.nets_utils import make_pad_mask
import logging
import torch.nn.functional as F
from funasr_local.modules.streaming_utils.utils import sequence_mask


class overlap_chunk():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SAN-M: Memory Equipped Self-Attention for End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(self,
                 chunk_size: tuple = (16,),
                 stride: tuple = (10,),
                 pad_left: tuple = (0,),
                 encoder_att_look_back_factor: tuple = (1,),
                 shfit_fsmn: int = 0,
                 decoder_att_look_back_factor: tuple = (1,),
                 ):
        pad_left = self.check_chunk_size_args(chunk_size, pad_left)
        encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor)
        decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor)
        self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor, self.decoder_att_look_back_factor \
            = chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor
        self.shfit_fsmn = shfit_fsmn
        self.x_add_mask = None
        self.x_rm_mask = None
        self.x_len = None
        self.mask_shfit_chunk = None
        self.mask_chunk_predictor = None
        self.mask_att_chunk_encoder = None
        self.mask_shift_att_chunk_decoder = None
        self.chunk_outs = None
        self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
            = None, None, None, None, None

    def check_chunk_size_args(self, chunk_size, x):
        # broadcast a length-1 tuple so every chunk-size candidate has a value
        if len(x) < len(chunk_size):
            x = [x[0] for _ in chunk_size]
        return x

    def get_chunk_size(self, ind: int = 0):
        chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor = \
            self.chunk_size[ind], self.stride[ind], self.pad_left[ind], self.encoder_att_look_back_factor[ind], self.decoder_att_look_back_factor[ind]
        self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur, self.decoder_att_look_back_factor_cur \
            = chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn, decoder_att_look_back_factor
        return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur

    def random_choice(self, training=True, decoding_ind=None):
        chunk_num = len(self.chunk_size)
        ind = 0
        if training and chunk_num > 1:
            # torch.randint's upper bound is exclusive; the original
            # `chunk_num - 1` could never sample the last configuration
            ind = torch.randint(0, chunk_num, ()).cpu().item()
        if not training and decoding_ind is not None:
            ind = int(decoding_ind)
        return ind

    def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
        with torch.no_grad():
            x_len = x_len.cpu().numpy()
            x_len_max = x_len.max()

            chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind)
            shfit_fsmn = self.shfit_fsmn
            pad_right = chunk_size - stride - pad_left

            chunk_num_batch = np.ceil(x_len / stride).astype(np.int32)
            x_len_chunk = (chunk_num_batch - 1) * chunk_size_pad_shift + shfit_fsmn + pad_left + x_len - (chunk_num_batch - 1) * stride
            x_len_chunk = x_len_chunk.astype(x_len.dtype)
            x_len_chunk_max = x_len_chunk.max()

            chunk_num = int(math.ceil(x_len_max / stride))
            dtype = np.int32
            max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
            x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
            x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
            mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
            mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
            mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
            mask_att_chunk_encoder = np.zeros([0, chunk_num * chunk_size_pad_shift], dtype=dtype)
            for chunk_ids in range(chunk_num):
                # x_mask add
                fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
                x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
                x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
                x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
                x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
                x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
                x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0)
                x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)

                # x_mask rm
                fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn), dtype=dtype)
                padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left), dtype=dtype)
                padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
                x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
                x_mask_cur_pad_top = np.zeros((chunk_ids * stride, stride), dtype=dtype)
                x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype)
                x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0)
                x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride]
                x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1)
                x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)

                # fsmn_padding_mask
                pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
                ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
                mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
                mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)

                # predictor mask
                zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
                ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
                zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype)
                ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
                mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0)
                mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)

                # encoder att mask
                zeros_1_top = np.zeros([shfit_fsmn, chunk_num * chunk_size_pad_shift], dtype=dtype)

                zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
                zeros_2 = np.zeros([chunk_size, zeros_2_num * chunk_size_pad_shift], dtype=dtype)

                encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
                zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
                ones_2_mid = np.ones([stride, stride], dtype=dtype)
                zeros_2_bottom = np.zeros([chunk_size - stride, stride], dtype=dtype)
                zeros_2_right = np.zeros([chunk_size, chunk_size - stride], dtype=dtype)
                ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
                ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
                ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])

                zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
                ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
                ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)

                zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
                zeros_remain = np.zeros([chunk_size, zeros_remain_num * chunk_size_pad_shift], dtype=dtype)

                ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1)
                mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0)
                mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0)

                # decoder fsmn_shift_att_mask
                zeros_1 = np.zeros([shfit_fsmn, 1])
                ones_1 = np.ones([chunk_size, 1])
                mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
                mask_shift_att_chunk_decoder = np.concatenate(
                    [mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0)

            self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max + pad_left]
            self.x_len_chunk = x_len_chunk
            self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
            self.x_len = x_len
            self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
            self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
            self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
            self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
            self.chunk_outs = (self.x_add_mask,
                               self.x_len_chunk,
                               self.x_rm_mask,
                               self.x_len,
                               self.mask_shfit_chunk,
                               self.mask_chunk_predictor,
                               self.mask_att_chunk_encoder,
                               self.mask_shift_att_chunk_decoder)

        return self.chunk_outs
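
    # Note: the tuple order above matches the default `idx` arguments of the
    # getters below (0 x_add_mask, 1 x_len_chunk, 2 x_rm_mask, 3 x_len,
    # 4 mask_shfit_chunk, 5 mask_chunk_predictor, 6 mask_att_chunk_encoder,
    # 7 mask_shift_att_chunk_decoder), so `chunk_outs` can be passed around
    # and unpacked by index.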

    def split_chunk(self, x, x_len, chunk_outs):
        """
        :param x: (b, t, d)
        :param x_len: (b)
        :param chunk_outs: masks from gen_chunk_mask
        :return: chunked x and its lengths
        """
        x = x[:, :x_len.max(), :]
        b, t, d = x.size()
        x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(x.device)
        x *= x_len_mask[:, :, None]  # zero padded frames (note: in place)

        x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
        x_len_chunk = self.get_x_len_chunk(chunk_outs, x_len.device, dtype=x_len.dtype)
        pad = (0, 0, self.pad_left_cur, 0)
        x = F.pad(x, pad, "constant", 0.0)
        b, t, d = x.size()
        # scatter frames into the overlapped chunk layout with one matmul
        x = torch.transpose(x, 1, 0)
        x = torch.reshape(x, [t, -1])
        x_chunk = torch.mm(x_add_mask, x)
        x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)

        return x_chunk, x_len_chunk

    def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
        x_chunk = x_chunk[:, :x_len_chunk.max(), :]
        b, t, d = x_chunk.size()
        x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(x_chunk.device)
        x_chunk *= x_len_chunk_mask[:, :, None]

        x_rm_mask = self.get_x_rm_mask(chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
        x_len = self.get_x_len(chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)
        x_chunk = torch.transpose(x_chunk, 1, 0)
        x_chunk = torch.reshape(x_chunk, [t, -1])
        x = torch.mm(x_rm_mask, x_chunk)
        x = torch.reshape(x, [-1, b, d]).transpose(1, 0)

        return x, x_len

    def get_x_add_mask(self, chunk_outs=None, device='cpu', idx=0, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
            return x

    def get_x_len_chunk(self, chunk_outs=None, device='cpu', idx=1, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
            return x

    def get_x_rm_mask(self, chunk_outs=None, device='cpu', idx=2, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
            return x

    def get_x_len(self, chunk_outs=None, device='cpu', idx=3, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
            return x

    def get_mask_shfit_chunk(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=4, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = np.tile(x[None, :, :], [batch_size, 1, num_units])
            x = torch.from_numpy(x).type(dtype).to(device)
            return x

    def get_mask_chunk_predictor(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=5, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = np.tile(x[None, :, :], [batch_size, 1, num_units])
            x = torch.from_numpy(x).type(dtype).to(device)
            return x

    def get_mask_att_chunk_encoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=6, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = np.tile(x[None, :, :], [batch_size, 1, 1])
            x = torch.from_numpy(x).type(dtype).to(device)
            return x

    def get_mask_shift_att_chunk_decoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=7, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = np.tile(x[None, None, :, 0], [batch_size, 1, 1])
            x = torch.from_numpy(x).type(dtype).to(device)
            return x


def build_scama_mask_for_cross_attention_decoder(
        predictor_alignments: torch.Tensor,
        encoder_sequence_length: torch.Tensor,
        chunk_size: int = 5,
        encoder_chunk_size: int = 5,
        attention_chunk_center_bias: int = 0,
        attention_chunk_size: int = 1,
        attention_chunk_type: str = 'chunk',
        step=None,
        predictor_mask_chunk_hopping: torch.Tensor = None,
        decoder_att_look_back_factor: int = 1,
        mask_shift_att_chunk_decoder: torch.Tensor = None,
        target_length: torch.Tensor = None,
        is_training=True,
        dtype: torch.dtype = torch.float32):
    with torch.no_grad():
        device = predictor_alignments.device
        batch_size, chunk_num = predictor_alignments.size()
        maximum_encoder_length = encoder_sequence_length.max().item()
        int_type = predictor_alignments.dtype
        if not is_training:
            target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype)
        maximum_target_length = target_length.max()
        predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1)
        predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1)

        index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, chunk_num)

        index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div == 0
        # index_div_bool_zeros_count[b, t] is the 1-based index of the chunk
        # in which the t-th target token fires (the first chunk whose
        # cumulative alignment reaches t)
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1
        index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num)

        # convert the chunk index into an encoder-frame boundary
        index_div_bool_zeros_count *= chunk_size
        index_div_bool_zeros_count += attention_chunk_center_bias
        index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count - 1, min=0, max=maximum_encoder_length)
        index_div_bool_zeros_count_ori = index_div_bool_zeros_count

        index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size) + 1) * encoder_chunk_size
        max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size

        mask_flip, mask_flip2 = None, None
        if attention_chunk_size is not None:
            index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size
            index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
            index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
            mask_flip = 1 - index_div_bool_zeros_count_beg_mask
            attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor + 1)
            index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2
            index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
            index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
            mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask

        mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device)

        if predictor_mask_chunk_hopping is not None:
            b, k, t = mask.size()
            predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1)

            mask_mask_flip = mask
            if mask_flip is not None:
                mask_mask_flip = mask_flip * mask

            def _fn():
                mask_sliced = mask[:b, :k, encoder_chunk_size:t]
                zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device)
                mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2)
                _, _, tt = predictor_mask_chunk_hopping.size()
                pad_right_p = max_len_chunk - tt
                predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0)
                masked = mask_sliced * predictor_mask_chunk_hopping_pad
                mask_true = mask_mask_flip + masked
                return mask_true

            mask = _fn() if t > chunk_size else mask_mask_flip

        if mask_flip2 is not None:
            mask *= mask_flip2

        # zero rows beyond each utterance's target length ...
        mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device)
        mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None]

        # ... and columns beyond each utterance's encoder length
        mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device)
        mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :]

        if attention_chunk_type == 'full':
            mask = torch.ones_like(mask).to(device)
        if mask_shift_att_chunk_decoder is not None:
            mask = mask * mask_shift_att_chunk_decoder
        mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device)

    return mask
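

# Minimal usage sketch with toy sizes. The shapes, lengths, and alignment
# values below are illustrative assumptions, not values used by the library.
if __name__ == '__main__':
    chunker = overlap_chunk(chunk_size=(16,), stride=(10,))
    x = torch.randn(2, 25, 4)                          # (batch, time, dim)
    x_len = torch.tensor([25, 17], dtype=torch.int32)
    ind = chunker.random_choice(training=False, decoding_ind=0)
    chunk_outs = chunker.gen_chunk_mask(x_len, ind=ind)
    x_chunk, x_len_chunk = chunker.split_chunk(x, x_len, chunk_outs)
    x_back, _ = chunker.remove_chunk(x_chunk, x_len_chunk, chunk_outs)
    # split_chunk zeroes the padded frames of x in place, so the round trip
    # should reproduce x exactly on every frame
    assert torch.allclose(x_back, x)

    # cross-attention mask from a toy alignment: tokens fired per chunk
    alignments = torch.tensor([[2, 1, 2]], dtype=torch.int64)
    enc_len = torch.tensor([12], dtype=torch.int64)
    scama_mask = build_scama_mask_for_cross_attention_decoder(
        predictor_alignments=alignments,
        encoder_sequence_length=enc_len,
        chunk_size=5,
        encoder_chunk_size=5,
        attention_chunk_size=5,
        is_training=False,
    )
    print(x_chunk.shape, scama_mask.shape)             # (2, 37, 4) (1, 5, 12)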
62
funasr_local/modules/streaming_utils/load_fr_tf.py
Normal file
@@ -0,0 +1,62 @@
import numpy as np
np.set_printoptions(threshold=np.inf)
import logging


def load_ckpt(checkpoint_path):
    import tensorflow as tf
    if tf.__version__.startswith('2'):
        import tensorflow.compat.v1 as tf
        tf.disable_v2_behavior()
        reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
    else:
        from tensorflow.python import pywrap_tensorflow
        reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()

    var_dict = dict()
    for var_name in sorted(var_to_shape_map):
        # skip optimizer slots, keep only model weights
        if "Adam" in var_name:
            continue
        tensor = reader.get_tensor(var_name)
        # print("in ckpt: {}, {}".format(var_name, tensor.shape))
        # print(tensor)
        var_dict[var_name] = tensor

    return var_dict


def load_tf_pb_dict(pb_model):
    import tensorflow as tf
    if tf.__version__.startswith('2'):
        import tensorflow.compat.v1 as tf
        tf.disable_v2_behavior()
        # import tensorflow_addons as tfa
        # from tensorflow_addons.seq2seq.python.ops import beam_search_ops
    else:
        from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
    from tensorflow.python.ops import lookup_ops as lookup
    from tensorflow.python.framework import tensor_util
    from tensorflow.python.platform import gfile

    sess = tf.Session()
    with gfile.FastGFile(pb_model, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')

    var_dict = dict()
    for node in sess.graph_def.node:
        if node.op == 'Const':
            value = tensor_util.MakeNdarray(node.attr['value'].tensor)
            if len(value.shape) >= 1:
                var_dict[node.name] = value
    return var_dict


def load_tf_dict(pb_model):
    # checkpoint prefixes look like "model.ckpt-<step>"; anything else is
    # treated as a frozen graph
    if "model.ckpt-" in pb_model:
        var_dict = load_ckpt(pb_model)
    else:
        var_dict = load_tf_pb_dict(pb_model)
    return var_dict
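

# Usage sketch: dump weight names and shapes from either format. The path
# below is a hypothetical placeholder.
if __name__ == '__main__':
    for name, tensor in load_tf_dict("exp/model.ckpt-100000").items():
        print(name, tensor.shape)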
91
funasr_local/modules/streaming_utils/utils.py
Normal file
@@ -0,0 +1,91 @@
import os
import torch
from torch.nn import functional as F
import yaml
import numpy as np


def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    if maxlen is None:
        maxlen = lengths.max()
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    mask = mask.detach()

    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
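
# Example: sequence_mask(torch.tensor([2, 4]), maxlen=5) returns
#   [[1., 1., 0., 0., 0.],
#    [1., 1., 1., 1., 0.]]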


def apply_cmvn(inputs, mvn):
    device = inputs.device
    dtype = inputs.dtype
    frame, dim = inputs.shape
    # row 0 of `mvn` holds the values subtracted from the features and
    # row 1 the values they are multiplied by
    means = np.tile(mvn[0:1, :dim], (frame, 1))
    vars = np.tile(mvn[1:2, :dim], (frame, 1))
    inputs -= torch.from_numpy(means).type(dtype).to(device)
    inputs *= torch.from_numpy(vars).type(dtype).to(device)

    return inputs.type(torch.float32)
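
# Sketch of constructing `mvn` from data, assuming the (2, dim) layout above
# (row 0: means to subtract, row 1: inverse standard deviations):
#
#   feats = np.random.randn(100, 80).astype(np.float32)
#   mvn = np.stack([feats.mean(axis=0), 1.0 / (feats.std(axis=0) + 1e-6)])
#   normalized = apply_cmvn(torch.from_numpy(feats), mvn)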


def drop_and_add(inputs: torch.Tensor,
                 outputs: torch.Tensor,
                 training: bool,
                 dropout_rate: float = 0.1,
                 stoch_layer_coeff: float = 1.0):
    outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
    outputs *= stoch_layer_coeff

    input_dim = inputs.size(-1)
    output_dim = outputs.size(-1)

    # residual connection only when the dimensions line up
    if input_dim == output_dim:
        outputs += inputs
    return outputs
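
# Sketch: dropout on the sublayer branch, then a residual add (shapes are
# illustrative; `sublayer` is a hypothetical module of matching width):
#
#   x = torch.randn(2, 50, 256)
#   y = sublayer(x)
#   out = drop_and_add(x, y, training=True, dropout_rate=0.1)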


def proc_tf_vocab(vocab_path):
    with open(vocab_path, encoding="utf-8") as f:
        token_list = [line.rstrip() for line in f]
    if '<unk>' not in token_list:
        token_list.append('<unk>')
    return token_list


def gen_config_for_tfmodel(config_path, vocab_path, output_dir):
    token_list = proc_tf_vocab(vocab_path)
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    config['token_list'] = token_list

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)


class NoAliasSafeDumper(yaml.SafeDumper):
    # disable anchor/alias in yaml because it looks ugly
    def ignore_aliases(self, data):
        return True


def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
    """Safe-dump in yaml with no anchor/alias"""
    return yaml.dump(
        data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
    )


if __name__ == '__main__':
    import sys

    config_path = sys.argv[1]
    vocab_path = sys.argv[2]
    output_dir = sys.argv[3]
    gen_config_for_tfmodel(config_path, vocab_path, output_dir)