Mirror of https://github.com/HumanAIGC/lite-avatar.git (synced 2026-02-05 18:09:20 +08:00)
add files
162
funasr_local/export/models/CT_Transformer.py
Normal file
@@ -0,0 +1,162 @@
from typing import Tuple

import torch
import torch.nn as nn

from funasr_local.models.encoder.sanm_encoder import SANMEncoder
from funasr_local.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr_local.models.encoder.sanm_encoder import SANMVadEncoder
from funasr_local.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export


class CT_Transformer(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        model_name='punc_model',
        **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        self.embed = model.embed
        self.decoder = model.decoder
        # self.model = model
        self.feats_dim = self.embed.embedding_dim
        self.num_embeddings = self.embed.num_embeddings
        self.model_name = model_name

        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        else:
            assert False, "Only the SANM encoder is supported."

    def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Predict punctuation logits for a batch of token sequences.

        Args:
            inputs (torch.Tensor): Input token ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence. (batch,)
        """
        x = self.embed(inputs)
        # mask = self._target_mask(input)
        h, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y

    def get_dummy_inputs(self):
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
        text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
        return (text_indexes, text_lengths)

    def get_input_names(self):
        return ['inputs', 'text_lengths']

    def get_output_names(self):
        return ['logits']

    def get_dynamic_axes(self):
        return {
            'inputs': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'text_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }


class CT_Transformer_VadRealtime(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        model_name='punc_model',
        **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]

        self.embed = model.embed
        if isinstance(model.encoder, SANMVadEncoder):
            self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
        else:
            assert False, "Only the SANM VAD encoder is supported."
        self.decoder = model.decoder
        self.model_name = model_name

    def forward(self, inputs: torch.Tensor,
                text_lengths: torch.Tensor,
                vad_indexes: torch.Tensor,
                sub_masks: torch.Tensor,
                ) -> Tuple[torch.Tensor, None]:
        """Predict punctuation logits for a batch of token sequences with VAD constraints.

        Args:
            inputs (torch.Tensor): Input token ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence. (batch,)
            vad_indexes (torch.Tensor): VAD attention mask. (batch, 1, len, len)
            sub_masks (torch.Tensor): Lower-triangular attention mask. (batch, 1, len, len)
        """
        x = self.embed(inputs)
        # mask = self._target_mask(input)
        h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
        y = self.decoder(h)
        return y

    def with_vad(self):
        return True

    def get_dummy_inputs(self):
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
        text_lengths = torch.tensor([length], dtype=torch.int32)
        vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
        sub_masks = torch.ones(length, length, dtype=torch.float32)
        sub_masks = torch.tril(sub_masks).type(torch.float32)
        return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])

    def get_input_names(self):
        return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']

    def get_output_names(self):
        return ['logits']

    def get_dynamic_axes(self):
        return {
            'inputs': {
                1: 'feats_length'
            },
            'vad_masks': {
                2: 'feats_length1',
                3: 'feats_length2'
            },
            'sub_masks': {
                2: 'feats_length1',
                3: 'feats_length2'
            },
            'logits': {
                1: 'logits_length'
            },
        }
25
funasr_local/export/models/__init__.py
Normal file
@@ -0,0 +1,25 @@
from funasr_local.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
from funasr_local.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr_local.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr_local.models.e2e_vad import E2EVadModel
from funasr_local.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
from funasr_local.models.target_delay_transformer import TargetDelayTransformer
from funasr_local.export.models.CT_Transformer import CT_Transformer as CT_Transformer_export
from funasr_local.train.abs_model import PunctuationModel
from funasr_local.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr_local.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export


def get_model(model, export_config=None):
    if isinstance(model, BiCifParaformer):
        return BiCifParaformer_export(model, **export_config)
    elif isinstance(model, Paraformer):
        return Paraformer_export(model, **export_config)
    elif isinstance(model, E2EVadModel):
        return E2EVadModel_export(model, **export_config)
    elif isinstance(model, PunctuationModel):
        if isinstance(model.punc_model, TargetDelayTransformer):
            return CT_Transformer_export(model.punc_model, **export_config)
        elif isinstance(model.punc_model, VadRealtimeTransformer):
            return CT_Transformer_VadRealtime_export(model.punc_model, **export_config)
    else:
        # raising a plain string is invalid Python; raise a proper exception instead
        raise TypeError("Funasr does not support the given model type currently.")
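All the export wrappers dispatched by get_model expose get_dummy_inputs(), get_input_names(), get_output_names() and get_dynamic_axes(), so a thin driver can hand any of them to torch.onnx.export. A minimal sketch, assuming a top-level model instance; the output path and opset version below are illustrative choices, not part of this commit:

    import torch
    from funasr_local.export.models import get_model

    def export_to_onnx(torch_model, out_path="model.onnx"):
        # Wrap the training-time model with its ONNX-friendly export class.
        m = get_model(torch_model, export_config={"onnx": True})
        m.eval()
        torch.onnx.export(
            m,
            m.get_dummy_inputs(),                  # tuple of dummy input tensors
            out_path,
            input_names=m.get_input_names(),
            output_names=m.get_output_names(),
            dynamic_axes=m.get_dynamic_axes(),
            opset_version=13,                      # assumed; pick what your runtime supports
        )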
0
funasr_local/export/models/decoder/__init__.py
Normal file
159
funasr_local/export/models/decoder/sanm_decoder.py
Normal file
@@ -0,0 +1,159 @@
import os

import torch
import torch.nn as nn

from funasr_local.export.utils.torch_function import MakePadMask
from funasr_local.export.utils.torch_function import sequence_mask

from funasr_local.modules.attention import MultiHeadedAttentionSANMDecoder
from funasr_local.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
from funasr_local.modules.attention import MultiHeadedAttentionCrossAtt
from funasr_local.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
from funasr_local.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr_local.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
from funasr_local.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export


class ParaformerSANMDecoder(nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True,):
        super().__init__()
        # self.embed = model.embed  # Embedding(model.embed, max_seq_len)
        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        for i, d in enumerate(self.model.decoders):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANM_export(d)

        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                    d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANM_export(d)

        for i, d in enumerate(self.model.decoders3):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            self.model.decoders3[i] = DecoderLayerSANM_export(d)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)

        return x, ys_in_lens

    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)

    def is_optimizable(self):
        return True

    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
            + ['cache_%d' % i for i in range(cache_num)]

    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['y'] \
            + ['out_cache_%d' % i for i in range(cache_num)]

    def get_dynamic_axes(self):
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        return ret

    def get_model_config(self, path):
        return {
            "dec_type": "XformerDecoder",
            "model_path": os.path.join(path, f'{self.model_name}.onnx'),
            "n_layers": len(self.model.decoders) + len(self.model.decoders2),
            "odim": self.model.decoders[0].size
        }
143
funasr_local/export/models/decoder/transformer_decoder.py
Normal file
@@ -0,0 +1,143 @@
import os
from funasr_local.export import models

import torch
import torch.nn as nn

from funasr_local.export.utils.torch_function import MakePadMask
from funasr_local.export.utils.torch_function import sequence_mask

from funasr_local.modules.attention import MultiHeadedAttentionSANMDecoder
from funasr_local.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
from funasr_local.modules.attention import MultiHeadedAttentionCrossAtt, MultiHeadedAttention
from funasr_local.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
from funasr_local.export.models.modules.multihead_att import OnnxMultiHeadedAttention
from funasr_local.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr_local.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
from funasr_local.export.models.modules.decoder_layer import DecoderLayer as DecoderLayer_export


class ParaformerDecoderSAN(nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True,):
        super().__init__()
        # self.embed = model.embed  # Embedding(model.embed, max_seq_len)
        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        for i, d in enumerate(self.model.decoders):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
            # if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
            #     d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
            if isinstance(d.src_attn, MultiHeadedAttention):
                d.src_attn = OnnxMultiHeadedAttention(d.src_attn)
            self.model.decoders[i] = DecoderLayer_export(d)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        x = tgt
        x, tgt_mask, memory, memory_mask = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)

        return x, ys_in_lens

    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)

    def is_optimizable(self):
        return True

    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
            + ['cache_%d' % i for i in range(cache_num)]

    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['y'] \
            + ['out_cache_%d' % i for i in range(cache_num)]

    def get_dynamic_axes(self):
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        return ret

    def get_model_config(self, path):
        return {
            "dec_type": "XformerDecoder",
            "model_path": os.path.join(path, f'{self.model_name}.onnx'),
            "n_layers": len(self.model.decoders) + len(self.model.decoders2),
            "odim": self.model.decoders[0].size
        }
219
funasr_local/export/models/e2e_asr_paraformer.py
Normal file
@@ -0,0 +1,219 @@
import logging
import torch
import torch.nn as nn

from funasr_local.export.utils.torch_function import MakePadMask
from funasr_local.export.utils.torch_function import sequence_mask
from funasr_local.models.encoder.sanm_encoder import SANMEncoder
from funasr_local.models.encoder.conformer_encoder import ConformerEncoder
from funasr_local.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr_local.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
from funasr_local.models.predictor.cif import CifPredictorV2, CifPredictorV3
from funasr_local.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
from funasr_local.export.models.predictor.cif import CifPredictorV3 as CifPredictorV3_export
from funasr_local.models.decoder.sanm_decoder import ParaformerSANMDecoder
from funasr_local.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr_local.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
from funasr_local.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export


class Paraformer(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='model',
        **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        elif isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        if isinstance(model.predictor, CifPredictorV2):
            self.predictor = CifPredictorV2_export(model.predictor)
        if isinstance(model.decoder, ParaformerSANMDecoder):
            self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
        elif isinstance(model.decoder, ParaformerDecoderSAN):
            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)

        self.feats_dim = feats_dim
        self.model_name = model_name

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        # batch = to_device(batch, device=self.device)

        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.floor().type(torch.int32)

        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # sample_ids = decoder_out.argmax(dim=-1)

        return decoder_out, pre_token_length

    def get_dummy_inputs(self):
        speech = torch.randn(2, 30, self.feats_dim)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        return (speech, speech_lengths)

    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
        import numpy as np
        fbank = np.loadtxt(txt_file)
        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
        return (speech, speech_lengths)

    def get_input_names(self):
        return ['speech', 'speech_lengths']

    def get_output_names(self):
        return ['logits', 'token_num']

    def get_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }


class BiCifParaformer(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='model',
        **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        elif isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        else:
            logging.warning("Unsupported encoder type to export.")
        if isinstance(model.predictor, CifPredictorV3):
            self.predictor = CifPredictorV3_export(model.predictor)
        else:
            logging.warning("Wrong predictor type to export.")
        if isinstance(model.decoder, ParaformerSANMDecoder):
            self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
        elif isinstance(model.decoder, ParaformerDecoderSAN):
            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
        else:
            logging.warning("Unsupported decoder type to export.")

        self.feats_dim = feats_dim
        self.model_name = model_name

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        # batch = to_device(batch, device=self.device)

        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.round().type(torch.int32)

        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)

        # get predicted timestamps
        us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)

        return decoder_out, pre_token_length, us_alphas, us_cif_peak

    def get_dummy_inputs(self):
        speech = torch.randn(2, 30, self.feats_dim)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        return (speech, speech_lengths)

    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
        import numpy as np
        fbank = np.loadtxt(txt_file)
        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
        return (speech, speech_lengths)

    def get_input_names(self):
        return ['speech', 'speech_lengths']

    def get_output_names(self):
        return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']

    def get_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
            'us_alphas': {
                0: 'batch_size',
                1: 'alphas_length'
            },
            'us_cif_peak': {
                0: 'batch_size',
                1: 'alphas_length'
            },
        }
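For reference, the log-probabilities returned by Paraformer.forward can be turned into a transcript by a greedy argmax, trimming each hypothesis to the predicted token count (the commented-out sample_ids line above hints at this). A minimal sketch; the token_list lookup table is an assumption about the surrounding inference code, not part of this file:

    logits, token_num = paraformer_export_model(speech, speech_lengths)
    sample_ids = logits.argmax(dim=-1)                    # (batch, max_token_len)
    for b in range(sample_ids.size(0)):
        ids = sample_ids[b, : token_num[b]].tolist()      # keep only the predicted tokens
        text = "".join(token_list[i] for i in ids)        # token_list: assumed id -> token table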
60
funasr_local/export/models/e2e_vad.py
Normal file
@@ -0,0 +1,60 @@
from enum import Enum
from typing import List, Tuple, Dict, Any

import torch
from torch import nn
import math

from funasr_local.models.encoder.fsmn_encoder import FSMN
from funasr_local.export.models.encoder.fsmn_encoder import FSMN as FSMN_export


class E2EVadModel(nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 feats_dim=400,
                 model_name='model',
                 **kwargs,):
        super(E2EVadModel, self).__init__()
        self.feats_dim = feats_dim
        self.max_seq_len = max_seq_len
        self.model_name = model_name
        if isinstance(model.encoder, FSMN):
            self.encoder = FSMN_export(model.encoder)
        else:
            # raising a plain string is invalid Python; raise a proper exception instead
            raise TypeError("unsupported encoder")

    def forward(self, feats: torch.Tensor, *args, ):
        scores, out_caches = self.encoder(feats, *args)
        return scores, out_caches

    def get_dummy_inputs(self, frame=30):
        speech = torch.randn(1, frame, self.feats_dim)
        in_cache0 = torch.randn(1, 128, 19, 1)
        in_cache1 = torch.randn(1, 128, 19, 1)
        in_cache2 = torch.randn(1, 128, 19, 1)
        in_cache3 = torch.randn(1, 128, 19, 1)

        return (speech, in_cache0, in_cache1, in_cache2, in_cache3)

    # def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
    #     import numpy as np
    #     fbank = np.loadtxt(txt_file)
    #     fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
    #     speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
    #     speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
    #     return (speech, speech_lengths)

    def get_input_names(self):
        return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']

    def get_output_names(self):
        return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']

    def get_dynamic_axes(self):
        return {
            'speech': {
                1: 'feats_length'
            },
        }
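E2EVadModel.forward threads the four FSMN caches through every call, which is how the exported graph is meant to be run on a stream of feature chunks. A minimal sketch; the zero-initialized caches and the chunk iterator are illustrative assumptions, with the cache shape taken from get_dummy_inputs above:

    import torch

    caches = [torch.zeros(1, 128, 19, 1) for _ in range(4)]   # assumed initial state for the first chunk
    for chunk in fbank_chunks:                                 # chunk: (1, frames, 400) fbank features
        # forward returns (scores, out_caches); feed the returned caches into the next call
        scores, caches = vad_export_model(chunk, *caches)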
0
funasr_local/export/models/encoder/__init__.py
Normal file
105
funasr_local/export/models/encoder/conformer_encoder.py
Normal file
@@ -0,0 +1,105 @@
import torch
import torch.nn as nn

from funasr_local.export.utils.torch_function import MakePadMask
from funasr_local.export.utils.torch_function import sequence_mask
from funasr_local.modules.attention import MultiHeadedAttentionSANM
from funasr_local.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export
from funasr_local.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
from funasr_local.export.models.modules.encoder_layer import EncoderLayerConformer as EncoderLayerConformer_export
from funasr_local.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr_local.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
from funasr_local.export.models.encoder.sanm_encoder import SANMEncoder
from funasr_local.modules.attention import RelPositionMultiHeadedAttention
# from funasr_local.export.models.modules.multihead_att import RelPositionMultiHeadedAttention as RelPositionMultiHeadedAttention_export
from funasr_local.export.models.modules.multihead_att import OnnxRelPosMultiHeadedAttention as RelPositionMultiHeadedAttention_export


class ConformerEncoder(nn.Module):
    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='encoder',
        onnx: bool = True,
    ):
        super().__init__()
        self.embed = model.embed
        self.model = model
        self.feats_dim = feats_dim
        self._output_size = model._output_size

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        for i, d in enumerate(self.model.encoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
                d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
            if isinstance(d.self_attn, RelPositionMultiHeadedAttention):
                d.self_attn = RelPositionMultiHeadedAttention_export(d.self_attn)
            if isinstance(d.feed_forward, PositionwiseFeedForward):
                d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
            self.model.encoders[i] = EncoderLayerConformer_export(d)

        self.model_name = model_name
        self.num_heads = model.encoders[0].self_attn.h
        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features

    def prepare_mask(self, mask):
        if len(mask.shape) == 2:
            mask = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask = 1 - mask[:, None, :]

        return mask * -10000.0

    def forward(self,
                speech: torch.Tensor,
                speech_lengths: torch.Tensor,
                ):
        mask = self.make_pad_mask(speech_lengths)
        mask = self.prepare_mask(mask)
        if self.embed is None:
            xs_pad = speech
        else:
            xs_pad = self.embed(speech)

        encoder_outs = self.model.encoders(xs_pad, mask)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]

        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]
        xs_pad = self.model.after_norm(xs_pad)

        return xs_pad, speech_lengths

    def get_output_size(self):
        return self.model.encoders[0].size

    def get_dummy_inputs(self):
        feats = torch.randn(1, 100, self.feats_dim)
        return (feats)

    def get_input_names(self):
        return ['feats']

    def get_output_names(self):
        return ['encoder_out', 'encoder_out_lens', 'predictor_weight']

    def get_dynamic_axes(self):
        return {
            'feats': {
                1: 'feats_length'
            },
            'encoder_out': {
                1: 'enc_out_length'
            },
            'predictor_weight': {
                1: 'pre_out_length'
            }
        }
296
funasr_local/export/models/encoder/fsmn_encoder.py
Normal file
@@ -0,0 +1,296 @@
from typing import Tuple, Dict
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr_local.models.encoder.fsmn_encoder import BasicBlock


class LinearTransform(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(LinearTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, input):
        output = self.linear(input)

        return output


class AffineTransform(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(AffineTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, input):
        output = self.linear(input)

        return output


class RectifiedLinear(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(RectifiedLinear, self).__init__()
        self.dim = input_dim
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, input):
        out = self.relu(input)
        return out


class FSMNBlock(nn.Module):

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        lorder=None,
        rorder=None,
        lstride=1,
        rstride=1,
    ):
        super(FSMNBlock, self).__init__()

        self.dim = input_dim

        if lorder is None:
            return

        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride

        self.conv_left = nn.Conv2d(
            self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)

        if self.rorder > 0:
            self.conv_right = nn.Conv2d(
                self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
        else:
            self.conv_right = None

    def forward(self, input: torch.Tensor, cache: torch.Tensor):
        x = torch.unsqueeze(input, 1)
        x_per = x.permute(0, 3, 2, 1)  # B D T C

        cache = cache.to(x_per.device)
        y_left = torch.cat((cache, x_per), dim=2)
        cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
        y_left = self.conv_left(y_left)
        out = x_per + y_left

        if self.conv_right is not None:
            # maybe need to check
            y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
            y_right = y_right[:, :, self.rstride:, :]
            y_right = self.conv_right(y_right)
            out += y_right

        out_per = out.permute(0, 3, 2, 1)
        output = out_per.squeeze(1)

        return output, cache


class BasicBlock_export(nn.Module):
    def __init__(self,
                 model,
                 ):
        super(BasicBlock_export, self).__init__()
        self.linear = model.linear
        self.fsmn_block = model.fsmn_block
        self.affine = model.affine
        self.relu = model.relu

    def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
        x = self.linear(input)  # B T D
        # cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
        # if cache_layer_name not in in_cache:
        #     in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
        x, out_cache = self.fsmn_block(x, in_cache)
        x = self.affine(x)
        x = self.relu(x)
        return x, out_cache


# class FsmnStack(nn.Sequential):
#     def __init__(self, *args):
#         super(FsmnStack, self).__init__(*args)
#
#     def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
#         x = input
#         for module in self._modules.values():
#             x = module(x, in_cache)
#         return x


'''
FSMN net for keyword spotting
input_dim:      input dimension
linear_dim:     fsmn input dimension
proj_dim:       fsmn projection dimension
lorder:         fsmn left order
rorder:         fsmn right order
num_syn:        output dimension
fsmn_layers:    no. of sequential fsmn layers
'''


class FSMN(nn.Module):
    def __init__(
        self, model,
    ):
        super(FSMN, self).__init__()

        # self.input_dim = input_dim
        # self.input_affine_dim = input_affine_dim
        # self.fsmn_layers = fsmn_layers
        # self.linear_dim = linear_dim
        # self.proj_dim = proj_dim
        # self.output_affine_dim = output_affine_dim
        # self.output_dim = output_dim
        #
        # self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
        # self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
        # self.relu = RectifiedLinear(linear_dim, linear_dim)
        # self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
        #                         range(fsmn_layers)])
        # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
        # self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
        # self.softmax = nn.Softmax(dim=-1)
        self.in_linear1 = model.in_linear1
        self.in_linear2 = model.in_linear2
        self.relu = model.relu
        # self.fsmn = model.fsmn
        self.out_linear1 = model.out_linear1
        self.out_linear2 = model.out_linear2
        self.softmax = model.softmax
        self.fsmn = model.fsmn
        for i, d in enumerate(model.fsmn):
            if isinstance(d, BasicBlock):
                self.fsmn[i] = BasicBlock_export(d)

    def fuse_modules(self):
        pass

    def forward(
        self,
        input: torch.Tensor,
        *args,
    ):
        """
        Args:
            input (torch.Tensor): Input tensor (B, T, D)
            in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
            {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
        """

        x = self.in_linear1(input)
        x = self.in_linear2(x)
        x = self.relu(x)
        # x4 = self.fsmn(x3, in_cache)  # self.in_cache will update automatically in self.fsmn
        out_caches = list()
        for i, d in enumerate(self.fsmn):
            in_cache = args[i]
            x, out_cache = d(x, in_cache)
            out_caches.append(out_cache)
        x = self.out_linear1(x)
        x = self.out_linear2(x)
        x = self.softmax(x)

        return x, out_caches


'''
one deep fsmn layer
dimproj:    projection dimension, input and output dimension of memory blocks
dimlinear:  dimension of mapping layer
lorder:     left order
rorder:     right order
lstride:    left stride
rstride:    right stride
'''


class DFSMN(nn.Module):

    def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
        super(DFSMN, self).__init__()

        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride

        self.expand = AffineTransform(dimproj, dimlinear)
        self.shrink = LinearTransform(dimlinear, dimproj)

        self.conv_left = nn.Conv2d(
            dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)

        if rorder > 0:
            self.conv_right = nn.Conv2d(
                dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
        else:
            self.conv_right = None

    def forward(self, input):
        f1 = F.relu(self.expand(input))
        p1 = self.shrink(f1)

        x = torch.unsqueeze(p1, 1)
        x_per = x.permute(0, 3, 2, 1)

        y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])

        if self.conv_right is not None:
            y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
            y_right = y_right[:, :, self.rstride:, :]
            out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
        else:
            out = x_per + self.conv_left(y_left)

        out1 = out.permute(0, 3, 2, 1)
        output = input + out1.squeeze(1)

        return output


'''
build stacked dfsmn layers
'''


def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
    repeats = [
        nn.Sequential(
            DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
        for i in range(fsmn_layers)
    ]

    return nn.Sequential(*repeats)


if __name__ == '__main__':
    # NOTE: this self-test exercises the original FSMN constructor signature and to_kaldi_net()
    # from funasr_local.models.encoder.fsmn_encoder, not the export wrapper defined above,
    # so it will not run against this class as written.
    fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
    print(fsmn)

    num_params = sum(p.numel() for p in fsmn.parameters())
    print('the number of model params: {}'.format(num_params))
    x = torch.zeros(128, 200, 400)  # batch-size * time * dim
    y, _ = fsmn(x)  # batch-size * time * dim
    print('input shape: {}'.format(x.shape))
    print('output shape: {}'.format(y.shape))

    print(fsmn.to_kaldi_net())
213
funasr_local/export/models/encoder/sanm_encoder.py
Normal file
@@ -0,0 +1,213 @@
import torch
import torch.nn as nn

from funasr_local.export.utils.torch_function import MakePadMask
from funasr_local.export.utils.torch_function import sequence_mask
from funasr_local.modules.attention import MultiHeadedAttentionSANM
from funasr_local.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export
from funasr_local.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
from funasr_local.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr_local.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export


class SANMEncoder(nn.Module):
    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='encoder',
        onnx: bool = True,
    ):
        super().__init__()
        self.embed = model.embed
        self.model = model
        self.feats_dim = feats_dim
        self._output_size = model._output_size

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        if hasattr(model, 'encoders0'):
            for i, d in enumerate(self.model.encoders0):
                if isinstance(d.self_attn, MultiHeadedAttentionSANM):
                    d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
                if isinstance(d.feed_forward, PositionwiseFeedForward):
                    d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
                self.model.encoders0[i] = EncoderLayerSANM_export(d)

        for i, d in enumerate(self.model.encoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
                d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
            if isinstance(d.feed_forward, PositionwiseFeedForward):
                d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
            self.model.encoders[i] = EncoderLayerSANM_export(d)

        self.model_name = model_name
        self.num_heads = model.encoders[0].self_attn.h
        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features

    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(self,
                speech: torch.Tensor,
                speech_lengths: torch.Tensor,
                ):
        speech = speech * self._output_size ** 0.5
        mask = self.make_pad_mask(speech_lengths)
        mask = self.prepare_mask(mask)
        if self.embed is None:
            xs_pad = speech
        else:
            xs_pad = self.embed(speech)

        encoder_outs = self.model.encoders0(xs_pad, mask)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]

        encoder_outs = self.model.encoders(xs_pad, mask)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.model.after_norm(xs_pad)

        return xs_pad, speech_lengths

    def get_output_size(self):
        return self.model.encoders[0].size

    def get_dummy_inputs(self):
        feats = torch.randn(1, 100, self.feats_dim)
        return (feats)

    def get_input_names(self):
        return ['feats']

    def get_output_names(self):
        return ['encoder_out', 'encoder_out_lens', 'predictor_weight']

    def get_dynamic_axes(self):
        return {
            'feats': {
                1: 'feats_length'
            },
            'encoder_out': {
                1: 'enc_out_length'
            },
            'predictor_weight': {
                1: 'pre_out_length'
            }
        }


class SANMVadEncoder(nn.Module):
    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='encoder',
        onnx: bool = True,
    ):
        super().__init__()
        self.embed = model.embed
        self.model = model
        self.feats_dim = feats_dim
        self._output_size = model._output_size

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        if hasattr(model, 'encoders0'):
            for i, d in enumerate(self.model.encoders0):
                if isinstance(d.self_attn, MultiHeadedAttentionSANM):
                    d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
                if isinstance(d.feed_forward, PositionwiseFeedForward):
                    d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
                self.model.encoders0[i] = EncoderLayerSANM_export(d)

        for i, d in enumerate(self.model.encoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
                d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
            if isinstance(d.feed_forward, PositionwiseFeedForward):
                d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
            self.model.encoders[i] = EncoderLayerSANM_export(d)

        self.model_name = model_name
        self.num_heads = model.encoders[0].self_attn.h
        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features

    def prepare_mask(self, mask, sub_masks):
        mask_3d_btd = mask[:, :, None]
        mask_4d_bhlt = (1 - sub_masks) * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(self,
                speech: torch.Tensor,
                speech_lengths: torch.Tensor,
                vad_masks: torch.Tensor,
                sub_masks: torch.Tensor,
                ):
        speech = speech * self._output_size ** 0.5
        mask = self.make_pad_mask(speech_lengths)
        vad_masks = self.prepare_mask(mask, vad_masks)
        mask = self.prepare_mask(mask, sub_masks)

        if self.embed is None:
            xs_pad = speech
        else:
            xs_pad = self.embed(speech)

        encoder_outs = self.model.encoders0(xs_pad, mask)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]

        # encoder_outs = self.model.encoders(xs_pad, mask)
        for layer_idx, encoder_layer in enumerate(self.model.encoders):
            if layer_idx == len(self.model.encoders) - 1:
                mask = vad_masks
            encoder_outs = encoder_layer(xs_pad, mask)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.model.after_norm(xs_pad)

        return xs_pad, speech_lengths

    def get_output_size(self):
        return self.model.encoders[0].size

    # def get_dummy_inputs(self):
    #     feats = torch.randn(1, 100, self.feats_dim)
    #     return (feats)
    #
    # def get_input_names(self):
    #     return ['feats']
    #
    # def get_output_names(self):
    #     return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
    #
    # def get_dynamic_axes(self):
    #     return {
    #         'feats': {
    #             1: 'feats_length'
    #         },
    #         'encoder_out': {
    #             1: 'enc_out_length'
    #         },
    #         'predictor_weight': {
    #             1: 'pre_out_length'
    #         }
    #     }
0
funasr_local/export/models/modules/__init__.py
Normal file
71
funasr_local/export/models/modules/decoder_layer.py
Normal file
@@ -0,0 +1,71 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
from torch import nn


class DecoderLayerSANM(nn.Module):

    def __init__(
        self,
        model
    ):
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
        self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
        self.size = model.size

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + x

        if self.src_attn is not None:
            residual = x
            x = self.norm3(x)
            x = residual + self.src_attn(x, memory, memory_mask)

        return x, tgt_mask, memory, memory_mask, cache


class DecoderLayer(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2
        self.norm3 = model.norm3

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        residual = tgt
        tgt = self.norm1(tgt)
        tgt_q = tgt
        tgt_q_mask = tgt_mask
        x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)

        residual = x
        x = self.norm2(x)

        x = residual + self.src_attn(x, memory, memory, memory_mask)

        residual = x
        x = self.norm3(x)
        x = residual + self.feed_forward(x)

        return x, tgt_mask, memory, memory_mask
91
funasr_local/export/models/modules/encoder_layer.py
Normal file
@@ -0,0 +1,91 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
from torch import nn


class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        model,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = model.self_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2
        self.in_size = model.in_size
        self.size = model.size

    def forward(self, x, mask):
        residual = x
        x = self.norm1(x)
        x = self.self_attn(x, mask)
        if self.in_size == self.size:
            x = x + residual
        residual = x
        x = self.norm2(x)
        x = self.feed_forward(x)
        x = x + residual

        return x, mask


class EncoderLayerConformer(nn.Module):
    def __init__(
        self,
        model,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = model.self_attn
        self.feed_forward = model.feed_forward
        self.feed_forward_macaron = model.feed_forward_macaron
        self.conv_module = model.conv_module
        self.norm_ff = model.norm_ff
        self.norm_mha = model.norm_mha
        self.norm_ff_macaron = model.norm_ff_macaron
        self.norm_conv = model.norm_conv
        self.norm_final = model.norm_final
        self.size = model.size

    def forward(self, x, mask):
        if isinstance(x, tuple):
            x, pos_emb = x[0], x[1]
        else:
            x, pos_emb = x, None

        if self.feed_forward_macaron is not None:
            residual = x
            x = self.norm_ff_macaron(x)
            x = residual + self.feed_forward_macaron(x) * 0.5

        residual = x
        x = self.norm_mha(x)

        x_q = x

        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)
        x = residual + x_att

        if self.conv_module is not None:
            residual = x
            x = self.norm_conv(x)
            x = residual + self.conv_module(x)

        residual = x
        x = self.norm_ff(x)
        x = residual + self.feed_forward(x) * 0.5

        x = self.norm_final(x)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask
31
funasr_local/export/models/modules/feedforward.py
Normal file
@@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn


class PositionwiseFeedForward(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.w_1 = model.w_1
        self.w_2 = model.w_2
        self.activation = model.activation

    def forward(self, x):
        x = self.activation(self.w_1(x))
        x = self.w_2(x)
        return x


class PositionwiseFeedForwardDecoderSANM(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.w_1 = model.w_1
        self.w_2 = model.w_2
        self.activation = model.activation
        self.norm = model.norm

    def forward(self, x):
        x = self.activation(self.w_1(x))
        x = self.w_2(self.norm(x))
        return x
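As above, these feed-forward wrappers only rebind weights. A minimal sketch with toy donor modules (illustrative only; the layer sizes are arbitrary) shows the decoder-SANM variant applying LayerNorm between the two projections:

import torch
import torch.nn as nn
from funasr_local.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM

donor = nn.Module()
donor.w_1 = nn.Linear(16, 32)
donor.activation = nn.ReLU()
donor.norm = nn.LayerNorm(32)
donor.w_2 = nn.Linear(32, 16)

ffn = PositionwiseFeedForwardDecoderSANM(donor)
x = torch.randn(2, 10, 16)
print(ffn(x).shape)                 # torch.Size([2, 10, 16]) -- w_2(norm(relu(w_1(x))))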
243
funasr_local/export/models/modules/multihead_att.py
Normal file
@@ -0,0 +1,243 @@
import os
import math

import torch
import torch.nn as nn


class MultiHeadedAttentionSANM(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.d_k = model.d_k
        self.h = model.h
        self.linear_out = model.linear_out
        self.linear_q_k_v = model.linear_q_k_v
        self.fsmn_block = model.fsmn_block
        self.pad_fn = model.pad_fn

        self.attn = None
        self.all_head_size = self.h * self.d_k

    def forward(self, x, mask):
        mask_3d_btd, mask_4d_bhlt = mask
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask_3d_btd)
        q_h = q_h * self.d_k**(-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask_4d_bhlt)
        return att_outs + fsmn_memory

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward_qkv(self, x):
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = self.transpose_for_scores(q)
        k_h = self.transpose_for_scores(k)
        v_h = self.transpose_for_scores(v)
        return q_h, k_h, v_h, v

    def forward_fsmn(self, inputs, mask):
        # b, t, d = inputs.size()
        # mask = torch.reshape(mask, (b, -1, 1))
        inputs = inputs * mask
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x = x + inputs
        x = x * mask
        return x

    def forward_attention(self, value, scores, mask):
        scores = scores + mask

        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.linear_out(context_layer)  # (batch, time1, d_model)
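# Note (added for readability, not in the original file): the SANM attention above
# combines plain multi-head attention with an FSMN memory branch, so `mask` is a
# pair rather than a single tensor:
#   mask_3d_btd  -- multiplicative padding mask, broadcast over the feature dim,
#                   applied to the values inside forward_fsmn;
#   mask_4d_bhlt -- additive attention mask (zeros for valid positions, large
#                   negatives for padding) added to the scores in forward_attention.
# The module returns att_outs + fsmn_memory, both with the same (batch, time,
# feature) shape as the projected values.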
def preprocess_for_attn(x, mask, cache, pad_fn):
    x = x * mask
    x = x.transpose(1, 2)
    if cache is None:
        x = pad_fn(x)
    else:
        x = torch.cat((cache[:, :, 1:], x), dim=2)
    cache = x
    return x, cache


# Register the helper as a leaf for torch.fx, so that symbolic tracing keeps it as a
# single call instead of specializing on the data-dependent `cache is None` branch.
torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]])
if torch_version >= (1, 8):
    import torch.fx
    torch.fx.wrap('preprocess_for_attn')
class MultiHeadedAttentionSANMDecoder(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.fsmn_block = model.fsmn_block
        self.pad_fn = model.pad_fn
        self.kernel_size = model.kernel_size
        self.attn = None

    def forward(self, inputs, mask, cache=None):
        x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)

        x = x + inputs
        x = x * mask
        return x, cache
class MultiHeadedAttentionCrossAtt(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.d_k = model.d_k
        self.h = model.h
        self.linear_q = model.linear_q
        self.linear_k_v = model.linear_k_v
        self.linear_out = model.linear_out
        self.attn = None
        self.all_head_size = self.h * self.d_k

    def forward(self, x, memory, memory_mask):
        q, k, v = self.forward_qkv(x, memory)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, memory_mask)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward_qkv(self, x, memory):
        q = self.linear_q(x)

        k_v = self.linear_k_v(memory)
        k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)
        return q, k, v

    def forward_attention(self, value, scores, mask):
        scores = scores + mask

        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.linear_out(context_layer)  # (batch, time1, d_model)
class OnnxMultiHeadedAttention(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.d_k = model.d_k
        self.h = model.h
        self.linear_q = model.linear_q
        self.linear_k = model.linear_k
        self.linear_v = model.linear_v
        self.linear_out = model.linear_out
        self.attn = None
        self.all_head_size = self.h * self.d_k

    def forward(self, query, key, value, mask):
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward_qkv(self, query, key, value):
        q = self.linear_q(query)
        k = self.linear_k(key)
        v = self.linear_v(value)
        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)
        return q, k, v

    def forward_attention(self, value, scores, mask):
        scores = scores + mask

        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.linear_out(context_layer)  # (batch, time1, d_model)
class OnnxRelPosMultiHeadedAttention(OnnxMultiHeadedAttention):
    def __init__(self, model):
        super().__init__(model)
        self.linear_pos = model.linear_pos
        self.pos_bias_u = model.pos_bias_u
        self.pos_bias_v = model.pos_bias_v

    def forward(self, query, key, value, pos_emb, mask):
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        p = self.transpose_for_scores(self.linear_pos(pos_emb))  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)

    def rel_shift(self, x):
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2
        return x

    def forward_attention(self, value, scores, mask):
        scores = scores + mask

        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.linear_out(context_layer)  # (batch, time1, d_model)
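These attention wrappers likewise only rebind projection weights, so a toy donor is enough to sanity-check shapes. A minimal sketch (illustrative only; the donor layers and sizes are made up):

import torch
import torch.nn as nn
from funasr_local.export.models.modules.multihead_att import OnnxMultiHeadedAttention

d_model, heads = 16, 4
donor = nn.Module()
donor.h = heads
donor.d_k = d_model // heads
donor.linear_q = nn.Linear(d_model, d_model)
donor.linear_k = nn.Linear(d_model, d_model)
donor.linear_v = nn.Linear(d_model, d_model)
donor.linear_out = nn.Linear(d_model, d_model)

att = OnnxMultiHeadedAttention(donor)
x = torch.randn(2, 10, d_model)
mask = torch.zeros(2, 1, 1, 10)     # additive mask: 0 keeps a position, a large negative would drop it
out = att(x, x, x, mask)
print(out.shape)                    # torch.Size([2, 10, 16])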
0
funasr_local/export/models/predictor/__init__.py
Normal file
288
funasr_local/export/models/predictor/cif.py
Normal file
@@ -0,0 +1,288 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
from torch import nn


def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    if maxlen is None:
        maxlen = lengths.max()
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    mask = mask.detach()

    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)


def sequence_mask_scripts(lengths, maxlen: int):
    row_vector = torch.arange(0, maxlen, 1).type(lengths.dtype).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    return mask.type(torch.float32).to(lengths.device)


class CifPredictorV2(nn.Module):
    def __init__(self, model):
        super().__init__()

        self.pad = model.pad
        self.cif_conv1d = model.cif_conv1d
        self.cif_output = model.cif_output
        self.threshold = model.threshold
        self.smooth_factor = model.smooth_factor
        self.noise_threshold = model.noise_threshold
        self.tail_threshold = model.tail_threshold

    def forward(self, hidden: torch.Tensor,
                mask: torch.Tensor,
                ):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))
        output = output.transpose(1, 2)

        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        mask = mask.transpose(-1, -2).float()
        alphas = alphas * mask
        alphas = alphas.squeeze(-1)
        token_num = alphas.sum(-1)

        mask = mask.squeeze(-1)
        hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

        return acoustic_embeds, token_num, alphas, cif_peak

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold

        zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
        ones_t = torch.ones_like(zeros_t)

        mask_1 = torch.cat([mask, zeros_t], dim=1)
        mask_2 = torch.cat([ones_t, mask], dim=1)
        mask = mask_2 - mask_1
        tail_threshold = mask * tail_threshold
        alphas = torch.cat([alphas, zeros_t], dim=1)
        alphas = torch.add(alphas, tail_threshold)

        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)

        return hidden, alphas, token_num_floor
# @torch.jit.script
# def cif(hidden, alphas, threshold: float):
#     batch_size, len_time, hidden_size = hidden.size()
#     threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
#
#     # loop varss
#     integrate = torch.zeros([batch_size], device=hidden.device)
#     frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
#     # intermediate vars along time
#     list_fires = []
#     list_frames = []
#
#     for t in range(len_time):
#         alpha = alphas[:, t]
#         distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
#
#         integrate += alpha
#         list_fires.append(integrate)
#
#         fire_place = integrate >= threshold
#         integrate = torch.where(fire_place,
#                                 integrate - torch.ones([batch_size], device=hidden.device),
#                                 integrate)
#         cur = torch.where(fire_place,
#                           distribution_completion,
#                           alpha)
#         remainds = alpha - cur
#
#         frame += cur[:, None] * hidden[:, t, :]
#         list_frames.append(frame)
#         frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
#                             remainds[:, None] * hidden[:, t, :],
#                             frame)
#
#     fires = torch.stack(list_fires, 1)
#     frames = torch.stack(list_frames, 1)
#     list_ls = []
#     len_labels = torch.floor(alphas.sum(-1)).int()
#     max_label_len = len_labels.max()
#     for b in range(batch_size):
#         fire = fires[b, :]
#         l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
#         pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device)
#         list_ls.append(torch.cat([l, pad_l], 0))
#     return torch.stack(list_ls, 0), fires
@torch.jit.script
def cif(hidden, alphas, threshold: float):
    batch_size, len_time, hidden_size = hidden.size()
    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)

    # loop vars
    integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
    frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
    # intermediate vars along time
    list_fires = []
    list_frames = []

    for t in range(len_time):
        alpha = alphas[:, t]
        distribution_completion = torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate

        integrate += alpha
        list_fires.append(integrate)

        fire_place = integrate >= threshold
        integrate = torch.where(fire_place,
                                integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
                                integrate)
        cur = torch.where(fire_place,
                          distribution_completion,
                          alpha)
        remainds = alpha - cur

        frame += cur[:, None] * hidden[:, t, :]
        list_frames.append(frame)
        frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
                            remainds[:, None] * hidden[:, t, :],
                            frame)

    fires = torch.stack(list_fires, 1)
    frames = torch.stack(list_frames, 1)

    fire_idxs = fires >= threshold
    frame_fires = torch.zeros_like(hidden)
    max_label_len = frames[0, fire_idxs[0]].size(0)
    for b in range(batch_size):
        frame_fire = frames[b, fire_idxs[b]]
        frame_len = frame_fire.size(0)
        frame_fires[b, :frame_len, :] = frame_fire

        if frame_len >= max_label_len:
            max_label_len = frame_len
    frame_fires = frame_fires[:, :max_label_len, :]
    return frame_fires, fires
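# Worked example (added for clarity, not in the original file): with threshold 1.0
# and per-frame weights alphas = [0.4, 0.7, 0.2, 0.9, 0.3, 0.6] for one utterance,
# the running integral crosses 1.0 at t = 1, 3 and 5, so cif() emits three acoustic
# embeddings; each one is the alpha-weighted sum of the encoder frames accumulated
# since the previous firing point, and `fires` keeps the integral used to locate peaks.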
class CifPredictorV3(nn.Module):
    def __init__(self, model):
        super().__init__()

        self.pad = model.pad
        self.cif_conv1d = model.cif_conv1d
        self.cif_output = model.cif_output
        self.threshold = model.threshold
        self.smooth_factor = model.smooth_factor
        self.noise_threshold = model.noise_threshold
        self.tail_threshold = model.tail_threshold

        self.upsample_times = model.upsample_times
        self.upsample_cnn = model.upsample_cnn
        self.blstm = model.blstm
        self.cif_output2 = model.cif_output2
        self.smooth_factor2 = model.smooth_factor2
        self.noise_threshold2 = model.noise_threshold2

    def forward(self, hidden: torch.Tensor,
                mask: torch.Tensor,
                ):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))
        output = output.transpose(1, 2)

        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        mask = mask.transpose(-1, -2).float()
        alphas = alphas * mask
        alphas = alphas.squeeze(-1)
        token_num = alphas.sum(-1)

        mask = mask.squeeze(-1)
        hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

        return acoustic_embeds, token_num, alphas, cif_peak

    def get_upsample_timestmap(self, hidden, mask=None, token_num=None):
        h = hidden
        b = hidden.shape[0]
        context = h.transpose(1, 2)

        # generate alphas2
        _output = context
        output2 = self.upsample_cnn(_output)
        output2 = output2.transpose(1, 2)
        output2, (_, _) = self.blstm(output2)
        alphas2 = torch.sigmoid(self.cif_output2(output2))
        alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)

        mask = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
        mask = mask.unsqueeze(-1)
        alphas2 = alphas2 * mask
        alphas2 = alphas2.squeeze(-1)
        _token_num = alphas2.sum(-1)
        alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
        # upsampled alphas and cif_peak
        us_alphas = alphas2
        us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
        return us_alphas, us_cif_peak

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold

        zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
        ones_t = torch.ones_like(zeros_t)

        mask_1 = torch.cat([mask, zeros_t], dim=1)
        mask_2 = torch.cat([ones_t, mask], dim=1)
        mask = mask_2 - mask_1
        tail_threshold = mask * tail_threshold
        alphas = torch.cat([alphas, zeros_t], dim=1)
        alphas = torch.add(alphas, tail_threshold)

        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)

        return hidden, alphas, token_num_floor


@torch.jit.script
def cif_wo_hidden(alphas, threshold: float):
    batch_size, len_time = alphas.size()

    # loop vars
    integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=alphas.device)
    # intermediate vars along time
    list_fires = []

    for t in range(len_time):
        alpha = alphas[:, t]

        integrate += alpha
        list_fires.append(integrate)

        fire_place = integrate >= threshold
        integrate = torch.where(fire_place,
                                integrate - torch.ones([batch_size], dtype=alphas.dtype, device=alphas.device),
                                integrate)

    fires = torch.stack(list_fires, 1)
    return fires
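A small end-to-end sketch of the predictor utilities above (not part of the commit; it assumes the package added by this diff is importable from the current PYTHONPATH):

import torch
from funasr_local.export.models.predictor.cif import cif, cif_wo_hidden

hidden = torch.randn(1, 6, 4)                                  # (batch, time, encoder_dim)
alphas = torch.tensor([[0.4, 0.7, 0.2, 0.9, 0.3, 0.6]])        # per-frame firing weights
frames, fires = cif(hidden, alphas, 1.0)
print(frames.shape)   # torch.Size([1, 3, 4]) -- the integral crosses 1.0 three times
print(fires.shape)    # torch.Size([1, 6])    -- running integral per input frame
peaks = cif_wo_hidden(alphas, 1.0 - 1e-4)                      # same integral, without frame accumulation
print(peaks.shape)    # torch.Size([1, 6])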