Mirror of https://github.com/HumanAIGC/lite-avatar.git, synced 2026-02-05 18:09:20 +08:00
add files
0    funasr_local/modules/beam_search/__init__.py    Normal file
348  funasr_local/modules/beam_search/batch_beam_search.py    Normal file
@@ -0,0 +1,348 @@
"""Parallel beam search module."""

import logging
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

from funasr_local.modules.beam_search.beam_search import BeamSearch
from funasr_local.modules.beam_search.beam_search import Hypothesis


class BatchHypothesis(NamedTuple):
    """Batchfied/Vectorized hypothesis data type."""

    yseq: torch.Tensor = torch.tensor([])  # (batch, maxlen)
    score: torch.Tensor = torch.tensor([])  # (batch,)
    length: torch.Tensor = torch.tensor([])  # (batch,)
    scores: Dict[str, torch.Tensor] = dict()  # values: (batch,)
    states: Dict[str, Dict] = dict()

    def __len__(self) -> int:
        """Return the batch size."""
        return len(self.length)


class BatchBeamSearch(BeamSearch):
    """Batch beam search implementation."""

    def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
        """Convert list to batch."""
        if len(hyps) == 0:
            return BatchHypothesis()
        return BatchHypothesis(
            yseq=pad_sequence(
                [h.yseq for h in hyps], batch_first=True, padding_value=self.eos
            ),
            length=torch.tensor([len(h.yseq) for h in hyps], dtype=torch.int64),
            score=torch.tensor([h.score for h in hyps]),
            scores={k: torch.tensor([h.scores[k] for h in hyps]) for k in self.scorers},
            states={k: [h.states[k] for h in hyps] for k in self.scorers},
        )

    def _batch_select(self, hyps: BatchHypothesis, ids: List[int]) -> BatchHypothesis:
        return BatchHypothesis(
            yseq=hyps.yseq[ids],
            score=hyps.score[ids],
            length=hyps.length[ids],
            scores={k: v[ids] for k, v in hyps.scores.items()},
            states={
                k: [self.scorers[k].select_state(v, i) for i in ids]
                for k, v in hyps.states.items()
            },
        )

    def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis:
        return Hypothesis(
            yseq=hyps.yseq[i, : hyps.length[i]],
            score=hyps.score[i],
            scores={k: v[i] for k, v in hyps.scores.items()},
            states={
                k: self.scorers[k].select_state(v, i) for k, v in hyps.states.items()
            },
        )

    def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
        """Revert batch to list."""
        return [
            Hypothesis(
                yseq=batch_hyps.yseq[i][: batch_hyps.length[i]],
                score=batch_hyps.score[i],
                scores={k: batch_hyps.scores[k][i] for k in self.scorers},
                states={
                    k: v.select_state(batch_hyps.states[k], i)
                    for k, v in self.scorers.items()
                },
            )
            for i in range(len(batch_hyps.length))
        ]

    def batch_beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Batch-compute topk full token ids and partial token ids.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each token.
                Its shape is `(n_beam, self.vocab_size)`.
            ids (torch.Tensor): The partial token ids to compute topk.
                Its shape is `(n_beam, self.pre_beam_size)`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                The topk full (prev_hyp, new_token) ids
                and partial (prev_hyp, new_token) ids.
                Their shapes are all `(self.beam_size,)`.

        """
        top_ids = weighted_scores.view(-1).topk(self.beam_size)[1]
        # Because of the flatten above, `top_ids` is organized as:
        # [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
        # where V is `self.n_vocab` and K is `self.beam_size`
        prev_hyp_ids = top_ids // self.n_vocab
        new_token_ids = top_ids % self.n_vocab
        return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids

    def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
        """Get an initial hypothesis data.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            Hypothesis: The initial hypothesis.

        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.batch_init_state(x)
            init_scores[k] = 0.0
        return self.batchfy(
            [
                Hypothesis(
                    score=0.0,
                    scores=init_scores,
                    states=init_states,
                    yseq=torch.tensor([self.sos], device=x.device),
                )
            ]
        )

    def score_full(
        self, hyp: BatchHypothesis, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x)
        return scores, states

    def score_partial(
        self, hyp: BatchHypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.part_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 2D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.part_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.batch_score_partial(
                hyp.yseq, ids, hyp.states[k], x
            )
        return scores, states

    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.

        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, v in part_states.items():
            new_states[k] = v
        return new_states

    def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis:
        """Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (BatchHypothesis): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)

        Returns:
            BatchHypothesis: Best sorted hypotheses

        """
        n_batch = len(running_hyps)
        part_ids = None  # no pre-beam
        # batch scoring
        weighted_scores = torch.zeros(
            n_batch, self.n_vocab, dtype=x.dtype, device=x.device
        )
        scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
        for k in self.full_scorers:
            weighted_scores += self.weights[k] * scores[k]
        # partial scoring
        if self.do_pre_beam:
            pre_beam_scores = (
                weighted_scores
                if self.pre_beam_score_key == "full"
                else scores[self.pre_beam_score_key]
            )
            part_ids = torch.topk(pre_beam_scores, self.pre_beam_size, dim=-1)[1]
        # NOTE(takaaki-hori): Unlike BeamSearch, we assume that score_partial returns
        # full-size score matrices, which have non-zero scores for part_ids and zeros
        # for others.
        part_scores, part_states = self.score_partial(running_hyps, part_ids, x)
        for k in self.part_scorers:
            weighted_scores += self.weights[k] * part_scores[k]
        # add previous hyp scores
        weighted_scores += running_hyps.score.to(
            dtype=x.dtype, device=x.device
        ).unsqueeze(1)

        # TODO(karita): do not use list. use batch instead
        # see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029
        # update hyps
        best_hyps = []
        prev_hyps = self.unbatchfy(running_hyps)
        for (
            full_prev_hyp_id,
            full_new_token_id,
            part_prev_hyp_id,
            part_new_token_id,
        ) in zip(*self.batch_beam(weighted_scores, part_ids)):
            prev_hyp = prev_hyps[full_prev_hyp_id]
            best_hyps.append(
                Hypothesis(
                    score=weighted_scores[full_prev_hyp_id, full_new_token_id],
                    yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
                    scores=self.merge_scores(
                        prev_hyp.scores,
                        {k: v[full_prev_hyp_id] for k, v in scores.items()},
                        full_new_token_id,
                        {k: v[part_prev_hyp_id] for k, v in part_scores.items()},
                        part_new_token_id,
                    ),
                    states=self.merge_states(
                        {
                            k: self.full_scorers[k].select_state(v, full_prev_hyp_id)
                            for k, v in states.items()
                        },
                        {
                            k: self.part_scorers[k].select_state(
                                v, part_prev_hyp_id, part_new_token_id
                            )
                            for k, v in part_states.items()
                        },
                        part_new_token_id,
                    ),
                )
            )
        return self.batchfy(best_hyps)

    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: BatchHypothesis,
        ended_hyps: List[Hypothesis],
    ) -> BatchHypothesis:
        """Perform post-processing of beam search iterations.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (float): The maximum length ratio in beam search.
            running_hyps (BatchHypothesis): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            BatchHypothesis: The new running hypotheses.

        """
        n_batch = running_hyps.yseq.shape[0]
        logging.debug(f"the number of running hypotheses: {n_batch}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: "
                + "".join(
                    [
                        self.token_list[x]
                        for x in running_hyps.yseq[0, 1 : running_hyps.length[0]]
                    ]
                )
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            yseq_eos = torch.cat(
                (
                    running_hyps.yseq,
                    torch.full(
                        (n_batch, 1),
                        self.eos,
                        device=running_hyps.yseq.device,
                        dtype=torch.int64,
                    ),
                ),
                1,
            )
            running_hyps.yseq.resize_as_(yseq_eos)
            running_hyps.yseq[:] = yseq_eos
            running_hyps.length[:] = yseq_eos.shape[1]

        # add ended hypotheses to a final list, and remove them from current hypotheses
        # (this will be a problem if the number of hyps < beam)
        is_eos = (
            running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1]
            == self.eos
        )
        for b in torch.nonzero(is_eos, as_tuple=False).view(-1):
            hyp = self._select(running_hyps, b)
            ended_hyps.append(hyp)
        remained_ids = torch.nonzero(is_eos == 0, as_tuple=False).view(-1)
        return self._batch_select(running_hyps, remained_ids)
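The index arithmetic in `batch_beam` above is what lets the whole beam advance in one tensor operation: the `(n_beam, n_vocab)` score matrix is flattened, a single top-k picks the best extensions globally, and integer division/modulo recover which hypothesis and which token each winner came from. A minimal standalone sketch of that decomposition (illustrative shapes and values only, not part of the committed file):

import torch

n_beam, n_vocab, beam_size = 3, 5, 4
weighted_scores = torch.randn(n_beam, n_vocab)  # joint score per (hypothesis, token)

# flatten and take one global top-k, exactly as batch_beam does
top_ids = weighted_scores.view(-1).topk(beam_size)[1]

# each flat index encodes prev_hyp * n_vocab + new_token
prev_hyp_ids = top_ids // n_vocab
new_token_ids = top_ids % n_vocab

# sanity check: the recovered pairs address the same scores as the flat top-k
assert torch.equal(
    weighted_scores[prev_hyp_ids, new_token_ids],
    weighted_scores.view(-1)[top_ids],
)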
270  funasr_local/modules/beam_search/batch_beam_search_online_sim.py    Normal file
@@ -0,0 +1,270 @@
"""Parallel beam search module for online simulation."""

import logging
from pathlib import Path
from typing import List

import yaml

import torch

from funasr_local.modules.beam_search.batch_beam_search import BatchBeamSearch
from funasr_local.modules.beam_search.beam_search import Hypothesis
from funasr_local.models.e2e_asr_common import end_detect


class BatchBeamSearchOnlineSim(BatchBeamSearch):
    """Online beam search implementation.

    This simulates streaming decoding.
    It requires the encoded features of the entire utterance and
    extracts them from it block by block, as would be done
    in streaming processing.
    This is based on Tsunoo et al., "STREAMING TRANSFORMER ASR
    WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH"
    (https://arxiv.org/abs/2006.14941).
    """

    def set_streaming_config(self, asr_config: str):
        """Set config file for streaming decoding.

        Args:
            asr_config (str): The config file for asr training

        """
        train_config_file = Path(asr_config)
        self.block_size = None
        self.hop_size = None
        self.look_ahead = None
        config = None
        with train_config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
            if "encoder_conf" in args.keys():
                if "block_size" in args["encoder_conf"].keys():
                    self.block_size = args["encoder_conf"]["block_size"]
                if "hop_size" in args["encoder_conf"].keys():
                    self.hop_size = args["encoder_conf"]["hop_size"]
                if "look_ahead" in args["encoder_conf"].keys():
                    self.look_ahead = args["encoder_conf"]["look_ahead"]
            elif "config" in args.keys():
                config = args["config"]
                if config is None:
                    logging.info(
                        "Cannot find config file for streaming decoding: "
                        + "apply batch beam search instead."
                    )
                    return
        if (
            self.block_size is None or self.hop_size is None or self.look_ahead is None
        ) and config is not None:
            config_file = Path(config)
            with config_file.open("r", encoding="utf-8") as f:
                args = yaml.safe_load(f)
            if "encoder_conf" in args.keys():
                enc_args = args["encoder_conf"]
            if enc_args and "block_size" in enc_args:
                self.block_size = enc_args["block_size"]
            if enc_args and "hop_size" in enc_args:
                self.hop_size = enc_args["hop_size"]
            if enc_args and "look_ahead" in enc_args:
                self.look_ahead = enc_args["look_ahead"]

    def set_block_size(self, block_size: int):
        """Set block size for streaming decoding.

        Args:
            block_size (int): The block size of encoder
        """
        self.block_size = block_size

    def set_hop_size(self, hop_size: int):
        """Set hop size for streaming decoding.

        Args:
            hop_size (int): The hop size of encoder
        """
        self.hop_size = hop_size

    def set_look_ahead(self, look_ahead: int):
        """Set look ahead size for streaming decoding.

        Args:
            look_ahead (int): The look ahead size of encoder
        """
        self.look_ahead = look_ahead

    def forward(
        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses an end-detect function
                to automatically find maximum hypothesis lengths
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results

        """
        self.conservative = True  # always true

        if self.block_size and self.hop_size and self.look_ahead:
            cur_end_frame = int(self.block_size - self.look_ahead)
        else:
            cur_end_frame = x.shape[0]
        process_idx = 0
        if cur_end_frame < x.shape[0]:
            h = x.narrow(0, 0, cur_end_frame)
        else:
            h = x

        # set length bounds
        if maxlenratio == 0:
            maxlen = x.shape[0]
        else:
            maxlen = max(1, int(maxlenratio * x.size(0)))
        minlen = int(minlenratio * x.size(0))
        logging.info("decoder input length: " + str(x.shape[0]))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # main loop of prefix search
        running_hyps = self.init_hyp(h)
        prev_hyps = []
        ended_hyps = []
        prev_repeat = False

        continue_decode = True

        while continue_decode:
            move_to_next_block = False
            if cur_end_frame < x.shape[0]:
                h = x.narrow(0, 0, cur_end_frame)
            else:
                h = x

            # extend states for ctc
            self.extend(h, running_hyps)

            while process_idx < maxlen:
                logging.debug("position " + str(process_idx))
                best = self.search(running_hyps, h)

                if process_idx == maxlen - 1:
                    # end decoding
                    running_hyps = self.post_process(
                        process_idx, maxlen, maxlenratio, best, ended_hyps
                    )
                n_batch = best.yseq.shape[0]
                local_ended_hyps = []
                is_local_eos = (
                    best.yseq[torch.arange(n_batch), best.length - 1] == self.eos
                )
                for i in range(is_local_eos.shape[0]):
                    if is_local_eos[i]:
                        hyp = self._select(best, i)
                        local_ended_hyps.append(hyp)
                    # NOTE(tsunoo): check repetitions here
                    # This is an implicit implementation of
                    # Eq (11) in https://arxiv.org/abs/2006.14941
                    # A flag prev_repeat is used instead of using set
                    elif (
                        not prev_repeat
                        and best.yseq[i, -1] in best.yseq[i, :-1]
                        and cur_end_frame < x.shape[0]
                    ):
                        move_to_next_block = True
                        prev_repeat = True
                if maxlenratio == 0.0 and end_detect(
                    [lh.asdict() for lh in local_ended_hyps], process_idx
                ):
                    logging.info(f"end detected at {process_idx}")
                    continue_decode = False
                    break
                if len(local_ended_hyps) > 0 and cur_end_frame < x.shape[0]:
                    move_to_next_block = True

                if move_to_next_block:
                    if (
                        self.hop_size
                        and cur_end_frame + int(self.hop_size) + int(self.look_ahead)
                        < x.shape[0]
                    ):
                        cur_end_frame += int(self.hop_size)
                    else:
                        cur_end_frame = x.shape[0]
                    logging.debug("Going to next block: %d", cur_end_frame)
                    if process_idx > 1 and len(prev_hyps) > 0 and self.conservative:
                        running_hyps = prev_hyps
                        process_idx -= 1
                        prev_hyps = []
                    break

                prev_repeat = False
                prev_hyps = running_hyps
                running_hyps = self.post_process(
                    process_idx, maxlen, maxlenratio, best, ended_hyps
                )

                if cur_end_frame >= x.shape[0]:
                    for hyp in local_ended_hyps:
                        ended_hyps.append(hyp)

                if len(running_hyps) == 0:
                    logging.info("no hypothesis. Finish decoding.")
                    continue_decode = False
                    break
                else:
                    logging.debug(f"remained hypotheses: {len(running_hyps)}")
                # increment number
                process_idx += 1

        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
            return (
                []
                if minlenratio < 0.1
                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
            )

        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
            )
        return nbest_hyps

    def extend(self, x: torch.Tensor, hyps: Hypothesis) -> List[Hypothesis]:
        """Extend probabilities and states with more encoded chunks.

        Args:
            x (torch.Tensor): The extended encoder output feature
            hyps (Hypothesis): Current list of hypotheses

        Returns:
            Hypothesis: The extended hypotheses

        """
        for k, d in self.scorers.items():
            if hasattr(d, "extend_prob"):
                d.extend_prob(x)
            if hasattr(d, "extend_state"):
                hyps.states[k] = d.extend_state(hyps.states[k])
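For orientation, a rough usage sketch of the class defined above. The constructor lives in the `BeamSearch` base class (its diff is suppressed below), so the argument names here follow the usual ESPnet-style signature and are assumptions for illustration rather than the exact FunASR wiring; `decoder`, `ctc`, `token_list`, `sos_id`, `eos_id`, `encoder`, and `speech` are placeholders:

# hypothetical wiring, not part of the committed files
beam_search = BatchBeamSearchOnlineSim(
    scorers={"decoder": decoder, "ctc": ctc},   # assumed scorer dict
    weights={"decoder": 0.7, "ctc": 0.3},       # assumed fusion weights
    beam_size=10,
    vocab_size=len(token_list),
    sos=sos_id,
    eos=eos_id,
    token_list=token_list,
)

# block/hop/look-ahead can be read from the training YAML via set_streaming_config,
# or set directly as below (frame counts are made-up values)
beam_search.set_block_size(40)
beam_search.set_hop_size(16)
beam_search.set_look_ahead(16)

enc_out = encoder(speech)             # (T, D) encoder output, placeholder call
nbest = beam_search.forward(enc_out)  # simulates block-wise streaming decoding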
1400  funasr_local/modules/beam_search/beam_search.py    Normal file
File diff suppressed because it is too large
704  funasr_local/modules/beam_search/beam_search_transducer.py    Normal file
@@ -0,0 +1,704 @@
"""Search algorithms for Transducer models."""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from funasr_local.models.joint_net.joint_network import JointNetwork


@dataclass
class Hypothesis:
    """Default hypothesis definition for Transducer search algorithms.

    Args:
        score: Total log-probability.
        yseq: Label sequence as integer ID sequence.
        dec_state: RNNDecoder or StatelessDecoder state.
            ((N, 1, D_dec), (N, 1, D_dec) or None) or None
        lm_state: RNNLM state. ((N, D_lm), (N, D_lm)) or None

    """

    score: float
    yseq: List[int]
    dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None
    lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None


@dataclass
class ExtendedHypothesis(Hypothesis):
    """Extended hypothesis definition for NSC beam search and mAES.

    Args:
        : Hypothesis dataclass arguments.
        dec_out: Decoder output sequence. (B, D_dec)
        lm_score: Log-probabilities of the LM for given label. (vocab_size)

    """

    dec_out: torch.Tensor = None
    lm_score: torch.Tensor = None


class BeamSearchTransducer:
    """Beam search implementation for Transducer.

    Args:
        decoder: Decoder module.
        joint_network: Joint network module.
        beam_size: Size of the beam.
        lm: LM class.
        lm_weight: LM weight for soft fusion.
        search_type: Search algorithm to use during inference.
        max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
        u_max: Maximum expected target sequence length. (ALSD)
        nstep: Number of maximum expansion steps at each time step. (mAES)
        expansion_gamma: Allowed logp difference for prune-by-value method. (mAES)
        expansion_beta:
            Number of additional candidates for expanded hypotheses selection. (mAES)
        score_norm: Normalize final scores by length.
        nbest: Number of final hypotheses.
        streaming: Whether to perform chunk-by-chunk beam search.

    """

    def __init__(
        self,
        decoder,
        joint_network: JointNetwork,
        beam_size: int,
        lm: Optional[torch.nn.Module] = None,
        lm_weight: float = 0.1,
        search_type: str = "default",
        max_sym_exp: int = 3,
        u_max: int = 50,
        nstep: int = 2,
        expansion_gamma: float = 2.3,
        expansion_beta: int = 2,
        score_norm: bool = False,
        nbest: int = 1,
        streaming: bool = False,
    ) -> None:
        """Construct a BeamSearchTransducer object."""
        super().__init__()

        self.decoder = decoder
        self.joint_network = joint_network

        self.vocab_size = decoder.vocab_size

        assert beam_size <= self.vocab_size, (
            "beam_size (%d) should be smaller than or equal to vocabulary size (%d)."
            % (
                beam_size,
                self.vocab_size,
            )
        )
        self.beam_size = beam_size

        if search_type == "default":
            self.search_algorithm = self.default_beam_search
        elif search_type == "tsd":
            assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (
                max_sym_exp
            )
            self.max_sym_exp = max_sym_exp

            self.search_algorithm = self.time_sync_decoding
        elif search_type == "alsd":
            assert not streaming, "ALSD is not available in streaming mode."

            assert u_max >= 0, "u_max should be a positive integer, a portion of max_T."
            self.u_max = u_max

            self.search_algorithm = self.align_length_sync_decoding
        elif search_type == "maes":
            assert self.vocab_size >= beam_size + expansion_beta, (
                "beam_size (%d) + expansion_beta (%d) "
                " should be smaller than or equal to vocab size (%d)."
                % (beam_size, expansion_beta, self.vocab_size)
            )
            self.max_candidates = beam_size + expansion_beta

            self.nstep = nstep
            self.expansion_gamma = expansion_gamma

            self.search_algorithm = self.modified_adaptive_expansion_search
        else:
            raise NotImplementedError(
                "Specified search type (%s) is not supported." % search_type
            )

        self.use_lm = lm is not None

        if self.use_lm:
            assert hasattr(lm, "rnn_type"), "Transformer LM is currently not supported."

            self.sos = self.vocab_size - 1

            self.lm = lm
            self.lm_weight = lm_weight

        self.score_norm = score_norm
        self.nbest = nbest

        self.reset_inference_cache()

    def __call__(
        self,
        enc_out: torch.Tensor,
        is_final: bool = True,
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            enc_out: Encoder output sequence. (T, D_enc)
            is_final: Whether enc_out is the final chunk of data.

        Returns:
            nbest_hyps: N-best decoding results

        """
        self.decoder.set_device(enc_out.device)

        hyps = self.search_algorithm(enc_out)

        if is_final:
            self.reset_inference_cache()

            return self.sort_nbest(hyps)

        self.search_cache = hyps

        return hyps

    def reset_inference_cache(self) -> None:
        """Reset cache for decoder scoring and streaming."""
        self.decoder.score_cache = {}
        self.search_cache = None

    def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
        """Sort in-place hypotheses by score or score given sequence length.

        Args:
            hyps: Hypotheses.

        Return:
            hyps: Sorted hypotheses.

        """
        if self.score_norm:
            hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
        else:
            hyps.sort(key=lambda x: x.score, reverse=True)

        return hyps[: self.nbest]

    def recombine_hyps(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
        """Recombine hypotheses with same label ID sequence.

        Args:
            hyps: Hypotheses.

        Returns:
            final: Recombined hypotheses.

        """
        final = {}

        for hyp in hyps:
            str_yseq = "_".join(map(str, hyp.yseq))

            if str_yseq in final:
                final[str_yseq].score = np.logaddexp(final[str_yseq].score, hyp.score)
            else:
                final[str_yseq] = hyp

        return [*final.values()]

    def select_k_expansions(
        self,
        hyps: List[ExtendedHypothesis],
        topk_idx: torch.Tensor,
        topk_logp: torch.Tensor,
    ) -> List[ExtendedHypothesis]:
        """Return K hypothesis candidates for expansion from a list of hypotheses.

        K candidates are selected according to the extended hypotheses probabilities
        and a prune-by-value method, where K is equal to beam_size + beta.

        Args:
            hyps: Hypotheses.
            topk_idx: Indices of candidate hypotheses.
            topk_logp: Log-probabilities of candidate hypotheses.

        Returns:
            k_expansions: Best K expansion hypotheses candidates.

        """
        k_expansions = []

        for i, hyp in enumerate(hyps):
            hyp_i = [
                (int(k), hyp.score + float(v))
                for k, v in zip(topk_idx[i], topk_logp[i])
            ]
            k_best_exp = max(hyp_i, key=lambda x: x[1])[1]

            k_expansions.append(
                sorted(
                    filter(
                        lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i
                    ),
                    key=lambda x: x[1],
                    reverse=True,
                )
            )

        return k_expansions

    def create_lm_batch_inputs(self, hyps_seq: List[List[int]]) -> torch.Tensor:
        """Make batch of inputs with left padding for LM scoring.

        Args:
            hyps_seq: Hypothesis sequences.

        Returns:
            : Padded batch of sequences.

        """
        max_len = max([len(h) for h in hyps_seq])

        return torch.LongTensor(
            [[self.sos] + ([0] * (max_len - len(h))) + h[1:] for h in hyps_seq],
            device=self.decoder.device,
        )

    def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Beam search implementation without prefix search.

        Modified from https://arxiv.org/pdf/1211.3711.pdf

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        beam_k = min(self.beam_size, (self.vocab_size - 1))
        max_t = len(enc_out)

        if self.search_cache is not None:
            kept_hyps = self.search_cache
        else:
            kept_hyps = [
                Hypothesis(
                    score=0.0,
                    yseq=[0],
                    dec_state=self.decoder.init_state(1),
                )
            ]

        for t in range(max_t):
            hyps = kept_hyps
            kept_hyps = []

            while True:
                max_hyp = max(hyps, key=lambda x: x.score)
                hyps.remove(max_hyp)

                label = torch.full(
                    (1, 1),
                    max_hyp.yseq[-1],
                    dtype=torch.long,
                    device=self.decoder.device,
                )
                dec_out, state = self.decoder.score(
                    label,
                    max_hyp.yseq,
                    max_hyp.dec_state,
                )

                logp = torch.log_softmax(
                    self.joint_network(enc_out[t : t + 1, :], dec_out),
                    dim=-1,
                ).squeeze(0)
                top_k = logp[1:].topk(beam_k, dim=-1)

                kept_hyps.append(
                    Hypothesis(
                        score=(max_hyp.score + float(logp[0:1])),
                        yseq=max_hyp.yseq,
                        dec_state=max_hyp.dec_state,
                        lm_state=max_hyp.lm_state,
                    )
                )

                if self.use_lm:
                    lm_scores, lm_state = self.lm.score(
                        torch.LongTensor(
                            [self.sos] + max_hyp.yseq[1:], device=self.decoder.device
                        ),
                        max_hyp.lm_state,
                        None,
                    )
                else:
                    lm_state = max_hyp.lm_state

                for logp, k in zip(*top_k):
                    score = max_hyp.score + float(logp)

                    if self.use_lm:
                        score += self.lm_weight * lm_scores[k + 1]

                    hyps.append(
                        Hypothesis(
                            score=score,
                            yseq=max_hyp.yseq + [int(k + 1)],
                            dec_state=state,
                            lm_state=lm_state,
                        )
                    )

                hyps_max = float(max(hyps, key=lambda x: x.score).score)
                kept_most_prob = sorted(
                    [hyp for hyp in kept_hyps if hyp.score > hyps_max],
                    key=lambda x: x.score,
                )
                if len(kept_most_prob) >= self.beam_size:
                    kept_hyps = kept_most_prob
                    break

        return kept_hyps

    def align_length_sync_decoding(
        self,
        enc_out: torch.Tensor,
    ) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        t_max = int(enc_out.size(0))
        u_max = min(self.u_max, (t_max - 1))

        B = [Hypothesis(yseq=[0], score=0.0, dec_state=self.decoder.init_state(1))]
        final = []

        if self.use_lm:
            B[0].lm_state = self.lm.zero_state()

        for i in range(t_max + u_max):
            A = []

            B_ = []
            B_enc_out = []
            for hyp in B:
                u = len(hyp.yseq) - 1
                t = i - u

                if t > (t_max - 1):
                    continue

                B_.append(hyp)
                B_enc_out.append((t, enc_out[t]))

            if B_:
                beam_enc_out = torch.stack([b[1] for b in B_enc_out])
                beam_dec_out, beam_state = self.decoder.batch_score(B_)

                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)

                if self.use_lm:
                    beam_lm_scores, beam_lm_states = self.lm.batch_score(
                        self.create_lm_batch_inputs([b.yseq for b in B_]),
                        [b.lm_state for b in B_],
                        None,
                    )

                for i, hyp in enumerate(B_):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[i, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )

                    A.append(new_hyp)

                    if B_enc_out[i][0] == (t_max - 1):
                        final.append(new_hyp)

                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, i),
                            lm_state=hyp.lm_state,
                        )

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                            new_hyp.lm_state = beam_lm_states[i]

                        A.append(new_hyp)

                B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]
                B = self.recombine_hyps(B)

        if final:
            return final

        return B

    def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Time synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        if self.search_cache is not None:
            B = self.search_cache
        else:
            B = [
                Hypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.init_state(1),
                )
            ]

            if self.use_lm:
                B[0].lm_state = self.lm.zero_state()

        for enc_out_t in enc_out:
            A = []
            C = B

            enc_out_t = enc_out_t.unsqueeze(0)

            for v in range(self.max_sym_exp):
                D = []

                beam_dec_out, beam_state = self.decoder.batch_score(C)

                beam_logp = torch.log_softmax(
                    self.joint_network(enc_out_t, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)

                seq_A = [h.yseq for h in A]

                for i, hyp in enumerate(C):
                    if hyp.yseq not in seq_A:
                        A.append(
                            Hypothesis(
                                score=(hyp.score + float(beam_logp[i, 0])),
                                yseq=hyp.yseq[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                            )
                        )
                    else:
                        dict_pos = seq_A.index(hyp.yseq)

                        A[dict_pos].score = np.logaddexp(
                            A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
                        )

                if v < (self.max_sym_exp - 1):
                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            self.create_lm_batch_inputs([c.yseq for c in C]),
                            [c.lm_state for c in C],
                            None,
                        )

                    for i, hyp in enumerate(C):
                        for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                yseq=(hyp.yseq + [int(k)]),
                                dec_state=self.decoder.select_state(beam_state, i),
                                lm_state=hyp.lm_state,
                            )

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                                new_hyp.lm_state = beam_lm_states[i]

                            D.append(new_hyp)

                    C = sorted(D, key=lambda x: x.score, reverse=True)[: self.beam_size]

            B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]

        return B

    def modified_adaptive_expansion_search(
        self,
        enc_out: torch.Tensor,
    ) -> List[ExtendedHypothesis]:
        """Modified version of Adaptive Expansion Search (mAES).

        Based on AES (https://ieeexplore.ieee.org/document/9250505) and
        NSC (https://arxiv.org/abs/2201.05420).

        Args:
            enc_out: Encoder output sequence. (T, D_enc)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        if self.search_cache is not None:
            kept_hyps = self.search_cache
        else:
            init_tokens = [
                ExtendedHypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.init_state(1),
                )
            ]

            beam_dec_out, beam_state = self.decoder.batch_score(
                init_tokens,
            )

            if self.use_lm:
                beam_lm_scores, beam_lm_states = self.lm.batch_score(
                    self.create_lm_batch_inputs([h.yseq for h in init_tokens]),
                    [h.lm_state for h in init_tokens],
                    None,
                )

                lm_state = beam_lm_states[0]
                lm_score = beam_lm_scores[0]
            else:
                lm_state = None
                lm_score = None

            kept_hyps = [
                ExtendedHypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.select_state(beam_state, 0),
                    dec_out=beam_dec_out[0],
                    lm_state=lm_state,
                    lm_score=lm_score,
                )
            ]

        for enc_out_t in enc_out:
            hyps = kept_hyps
            kept_hyps = []

            beam_enc_out = enc_out_t.unsqueeze(0)

            list_b = []
            for n in range(self.nstep):
                beam_dec_out = torch.stack([h.dec_out for h in hyps])

                beam_logp, beam_idx = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                ).topk(self.max_candidates, dim=-1)

                k_expansions = self.select_k_expansions(hyps, beam_idx, beam_logp)

                list_exp = []
                for i, hyp in enumerate(hyps):
                    for k, new_score in k_expansions[i]:
                        new_hyp = ExtendedHypothesis(
                            yseq=hyp.yseq[:],
                            score=new_score,
                            dec_out=hyp.dec_out,
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_score=hyp.lm_score,
                        )

                        if k == 0:
                            list_b.append(new_hyp)
                        else:
                            new_hyp.yseq.append(int(k))

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * float(hyp.lm_score[k])

                            list_exp.append(new_hyp)

                if not list_exp:
                    kept_hyps = sorted(
                        self.recombine_hyps(list_b), key=lambda x: x.score, reverse=True
                    )[: self.beam_size]

                    break
                else:
                    beam_dec_out, beam_state = self.decoder.batch_score(
                        list_exp,
                    )

                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            self.create_lm_batch_inputs([h.yseq for h in list_exp]),
                            [h.lm_state for h in list_exp],
                            None,
                        )

                    if n < (self.nstep - 1):
                        for i, hyp in enumerate(list_exp):
                            hyp.dec_out = beam_dec_out[i]
                            hyp.dec_state = self.decoder.select_state(beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = beam_lm_states[i]
                                hyp.lm_score = beam_lm_scores[i]

                        hyps = list_exp[:]
                    else:
                        beam_logp = torch.log_softmax(
                            self.joint_network(beam_enc_out, beam_dec_out),
                            dim=-1,
                        )

                        for i, hyp in enumerate(list_exp):
                            hyp.score += float(beam_logp[i, 0])

                            hyp.dec_out = beam_dec_out[i]
                            hyp.dec_state = self.decoder.select_state(beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = beam_lm_states[i]
                                hyp.lm_score = beam_lm_scores[i]

                        kept_hyps = sorted(
                            self.recombine_hyps(list_b + list_exp),
                            key=lambda x: x.score,
                            reverse=True,
                        )[: self.beam_size]

        return kept_hyps
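Finally, a rough usage sketch for `BeamSearchTransducer`. `decoder`, `joint_network`, `enc_out`, and `enc_chunks` are placeholders for the corresponding FunASR modules and tensors (their construction is not part of this diff), so treat the setup as an assumption rather than the exact calling convention used elsewhere in the repository:

# offline decoding: a single call over the full encoder output
searcher = BeamSearchTransducer(
    decoder=decoder,              # placeholder Transducer decoder module
    joint_network=joint_network,  # placeholder JointNetwork instance
    beam_size=5,
    search_type="maes",           # or "default", "tsd", "alsd"
    nbest=3,
)
nbest_hyps = searcher(enc_out)    # enc_out: (T, D_enc); returns sorted n-best

# chunk-by-chunk decoding: the internal search cache is kept until the final chunk
for i, chunk in enumerate(enc_chunks):
    hyps = searcher(chunk, is_final=(i == len(enc_chunks) - 1))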