Mirror of https://github.com/HumanAIGC/lite-avatar.git, synced 2026-02-05 18:09:20 +08:00
add files
funasr_local/modules/beam_search/batch_beam_search.py (new file, +348)
@@ -0,0 +1,348 @@
"""Parallel beam search module."""

import logging
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

from funasr_local.modules.beam_search.beam_search import BeamSearch
from funasr_local.modules.beam_search.beam_search import Hypothesis

class BatchHypothesis(NamedTuple):
    """Batchified/vectorized hypothesis data type."""

    yseq: torch.Tensor = torch.tensor([])  # (batch, maxlen)
    score: torch.Tensor = torch.tensor([])  # (batch,)
    length: torch.Tensor = torch.tensor([])  # (batch,)
    scores: Dict[str, torch.Tensor] = dict()  # values: (batch,)
    states: Dict[str, Dict] = dict()

    def __len__(self) -> int:
        """Return the batch size."""
        return len(self.length)
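
# Illustrative note (not part of the original file): a sketch of the fields
# above. Assuming two hypotheses of lengths 3 and 2, batchfied with eos=2 as
# the padding value (`state0`, `state1` are hypothetical scorer states):
#
#   BatchHypothesis(
#       yseq=torch.tensor([[1, 5, 7], [1, 4, 2]]),  # second row padded with eos
#       score=torch.tensor([-3.2, -4.1]),
#       length=torch.tensor([3, 2]),
#       scores={"decoder": torch.tensor([-3.2, -4.1])},
#       states={"decoder": [state0, state1]},
#   )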

class BatchBeamSearch(BeamSearch):
    """Batch beam search implementation."""

    def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
        """Convert list to batch."""
        if len(hyps) == 0:
            return BatchHypothesis()
        return BatchHypothesis(
            yseq=pad_sequence(
                [h.yseq for h in hyps], batch_first=True, padding_value=self.eos
            ),
            length=torch.tensor([len(h.yseq) for h in hyps], dtype=torch.int64),
            score=torch.tensor([h.score for h in hyps]),
            scores={k: torch.tensor([h.scores[k] for h in hyps]) for k in self.scorers},
            states={k: [h.states[k] for h in hyps] for k in self.scorers},
        )

    def _batch_select(self, hyps: BatchHypothesis, ids: List[int]) -> BatchHypothesis:
        """Select hypotheses by batch indices, narrowing scorer states to match."""
        return BatchHypothesis(
            yseq=hyps.yseq[ids],
            score=hyps.score[ids],
            length=hyps.length[ids],
            scores={k: v[ids] for k, v in hyps.scores.items()},
            states={
                k: [self.scorers[k].select_state(v, i) for i in ids]
                for k, v in hyps.states.items()
            },
        )

    def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis:
        """Extract a single un-padded Hypothesis from the batch."""
        return Hypothesis(
            yseq=hyps.yseq[i, : hyps.length[i]],
            score=hyps.score[i],
            scores={k: v[i] for k, v in hyps.scores.items()},
            states={
                k: self.scorers[k].select_state(v, i) for k, v in hyps.states.items()
            },
        )

    def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
        """Revert batch to list."""
        return [
            Hypothesis(
                yseq=batch_hyps.yseq[i][: batch_hyps.length[i]],
                score=batch_hyps.score[i],
                scores={k: batch_hyps.scores[k][i] for k in self.scorers},
                states={
                    k: v.select_state(batch_hyps.states[k], i)
                    for k, v in self.scorers.items()
                },
            )
            for i in range(len(batch_hyps.length))
        ]
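
    # Illustrative note (not in the original file): `batchfy` and `unbatchfy`
    # are inverses up to padding. A minimal round-trip sketch, assuming `bs`
    # is a constructed BatchBeamSearch instance and `h1`, `h2` are Hypothesis
    # objects (hypothetical names):
    #
    #   batch = bs.batchfy([h1, h2])        # pads yseqs into a (2, maxlen) tensor
    #   h1_rt, h2_rt = bs.unbatchfy(batch)  # padding stripped via `length`
    #   assert torch.equal(h1_rt.yseq, h1.yseq)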

    def batch_beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Batch-compute topk full token ids and partial token ids.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each token.
                Its shape is `(n_beam, self.vocab_size)`.
            ids (torch.Tensor): The partial token ids to compute topk.
                Its shape is `(n_beam, self.pre_beam_size)`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                The topk full (prev_hyp, new_token) ids
                and partial (prev_hyp, new_token) ids.
                Their shapes are all `(self.beam_size,)`.

        """
        top_ids = weighted_scores.view(-1).topk(self.beam_size)[1]
        # Because of the flatten above, `top_ids` is organized as:
        # [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
        # where V is `self.n_vocab` and K is `self.beam_size`
        prev_hyp_ids = top_ids // self.n_vocab
        new_token_ids = top_ids % self.n_vocab
        return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids
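
    # Illustrative note (not in the original file): a worked example of the
    # flatten/topk index arithmetic above. With n_vocab = 10 and beam_size = 2,
    # suppose flattening a (2, 10) score matrix gives top_ids = [17, 3]:
    #
    #   prev_hyp_ids = [17 // 10, 3 // 10]  -> [1, 0]  (row = source hypothesis)
    #   new_token_ids = [17 % 10, 3 % 10]   -> [7, 3]  (column = token id)
    #
    # i.e. the best continuation extends hypothesis 1 with token 7, and the
    # second best extends hypothesis 0 with token 3.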

    def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
        """Get initial hypothesis data.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            BatchHypothesis: The initial hypothesis.

        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.batch_init_state(x)
            init_scores[k] = 0.0
        return self.batchfy(
            [
                Hypothesis(
                    score=0.0,
                    scores=init_scores,
                    states=init_states,
                    yseq=torch.tensor([self.sos], device=x.device),
                )
            ]
        )
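
    # Illustrative note (not in the original file): the initial batch holds a
    # single hypothesis, so the search starts from, e.g. with sos = 1:
    #
    #   yseq=tensor([[1]]), score=tensor([0.]), length=tensor([1])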

    def score_full(
        self, hyp: BatchHypothesis, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypotheses by `self.full_scorers`.

        Args:
            hyp (BatchHypothesis): Hypotheses with prefix tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x)
        return scores, states
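
    # Illustrative note (not in the original file): a minimal sketch of the
    # scorer contract assumed here, modeled on ESPnet's BatchScorerInterface.
    # `batch_score` takes the padded prefixes `ys` (n_batch, ylen), a list of
    # per-hypothesis states, and the expanded encoder features `xs`
    # (n_batch, T, D), and returns log-probability scores of shape
    # (n_batch, n_vocab) plus updated states. A hypothetical toy scorer:
    #
    #   class UniformScorer:
    #       """Hypothetical scorer: equal log-probability for every token."""
    #
    #       def batch_init_state(self, x):
    #           return None
    #
    #       def batch_score(self, ys, states, xs):
    #           n_batch, n_vocab = ys.size(0), 10  # vocabulary size assumed
    #           return torch.full((n_batch, n_vocab), -2.3), states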

    def score_partial(
        self, hyp: BatchHypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypotheses by `self.part_scorers`.

        Args:
            hyp (BatchHypothesis): Hypotheses with prefix tokens to score
            ids (torch.Tensor): 2D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.part_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.batch_score_partial(
                hyp.yseq, ids, hyp.states[k], x
            )
        return scores, states
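
    # Illustrative note (not in the original file): as the NOTE in `search`
    # below points out, each partial scorer is expected to return a full
    # (n_batch, n_vocab) matrix that is non-zero only at the pre-beam
    # candidate columns, e.g. for n_vocab = 5 and part_ids = [[1, 3]]:
    #
    #   tensor([[0.0, -1.2, 0.0, -0.7, 0.0]])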

    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `self.part_scorers`

        Returns:
            Dict[str, Any]: The new state dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.

        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, v in part_states.items():
            new_states[k] = v
        return new_states

    def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis:
        """Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (BatchHypothesis): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)

        Returns:
            BatchHypothesis: Best sorted hypotheses

        """
        n_batch = len(running_hyps)
        part_ids = None  # no pre-beam
        # batch scoring
        weighted_scores = torch.zeros(
            n_batch, self.n_vocab, dtype=x.dtype, device=x.device
        )
        scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
        for k in self.full_scorers:
            weighted_scores += self.weights[k] * scores[k]
        # partial scoring
        if self.do_pre_beam:
            pre_beam_scores = (
                weighted_scores
                if self.pre_beam_score_key == "full"
                else scores[self.pre_beam_score_key]
            )
            part_ids = torch.topk(pre_beam_scores, self.pre_beam_size, dim=-1)[1]
        # NOTE(takaaki-hori): Unlike BeamSearch, we assume that score_partial returns
        # full-size score matrices, which have non-zero scores for part_ids and zeros
        # for others.
        part_scores, part_states = self.score_partial(running_hyps, part_ids, x)
        for k in self.part_scorers:
            weighted_scores += self.weights[k] * part_scores[k]
        # add previous hyp scores
        weighted_scores += running_hyps.score.to(
            dtype=x.dtype, device=x.device
        ).unsqueeze(1)

        # TODO(karita): do not use list. use batch instead
        # see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029
        # update hyps
        best_hyps = []
        prev_hyps = self.unbatchfy(running_hyps)
        for (
            full_prev_hyp_id,
            full_new_token_id,
            part_prev_hyp_id,
            part_new_token_id,
        ) in zip(*self.batch_beam(weighted_scores, part_ids)):
            prev_hyp = prev_hyps[full_prev_hyp_id]
            best_hyps.append(
                Hypothesis(
                    score=weighted_scores[full_prev_hyp_id, full_new_token_id],
                    yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
                    scores=self.merge_scores(
                        prev_hyp.scores,
                        {k: v[full_prev_hyp_id] for k, v in scores.items()},
                        full_new_token_id,
                        {k: v[part_prev_hyp_id] for k, v in part_scores.items()},
                        part_new_token_id,
                    ),
                    states=self.merge_states(
                        {
                            k: self.full_scorers[k].select_state(v, full_prev_hyp_id)
                            for k, v in states.items()
                        },
                        {
                            k: self.part_scorers[k].select_state(
                                v, part_prev_hyp_id, part_new_token_id
                            )
                            for k, v in part_states.items()
                        },
                        part_new_token_id,
                    ),
                )
            )
        return self.batchfy(best_hyps)
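
    # Illustrative note (not in the original file): in one search step the
    # combined score of extending hypothesis b with token v is
    #
    #   weighted_scores[b, v] = sum_k weights[k] * scores[k][b, v]
    #                           + running_hyps.score[b]
    #
    # For example, with weights {"decoder": 0.7, "ctc": 0.3}, a decoder score
    # of -1.0, a CTC score of -2.0, and a previous hypothesis score of -3.0:
    #
    #   0.7 * (-1.0) + 0.3 * (-2.0) + (-3.0) = -4.3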

    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: BatchHypothesis,
        ended_hyps: List[Hypothesis],
    ) -> BatchHypothesis:
        """Perform post-processing of beam search iterations.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (float): The maximum length ratio in beam search.
            running_hyps (BatchHypothesis): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            BatchHypothesis: The new running hypotheses.

        """
        n_batch = running_hyps.yseq.shape[0]
        logging.debug(f"the number of running hypotheses: {n_batch}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: "
                + "".join(
                    [
                        self.token_list[x]
                        for x in running_hyps.yseq[0, 1 : running_hyps.length[0]]
                    ]
                )
            )
        # add eos in the final loop so that at least one hypothesis ends
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            yseq_eos = torch.cat(
                (
                    running_hyps.yseq,
                    torch.full(
                        (n_batch, 1),
                        self.eos,
                        device=running_hyps.yseq.device,
                        dtype=torch.int64,
                    ),
                ),
                1,
            )
            running_hyps.yseq.resize_as_(yseq_eos)
            running_hyps.yseq[:] = yseq_eos
            running_hyps.length[:] = yseq_eos.shape[1]

        # add ended hypotheses to a final list, and remove them from the current
        # hypotheses (this can be a problem: the number of remaining hyps may
        # drop below the beam size)
        is_eos = (
            running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1]
            == self.eos
        )
        for b in torch.nonzero(is_eos, as_tuple=False).view(-1):
            hyp = self._select(running_hyps, b)
            ended_hyps.append(hyp)
        remained_ids = torch.nonzero(is_eos == 0, as_tuple=False).view(-1)
        return self._batch_select(running_hyps, remained_ids)
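
# Illustrative note (not in the original file): a hedged end-to-end usage
# sketch. The constructor arguments mirror ESPnet's BeamSearch, which this
# class extends; `my_decoder` and `enc_out` are hypothetical placeholders.
#
#   bs = BatchBeamSearch(
#       scorers={"decoder": my_decoder},  # must implement the batch scorer API
#       weights={"decoder": 1.0},
#       beam_size=5,
#       vocab_size=1000,
#       sos=999,
#       eos=999,
#   )
#   results = bs(enc_out)  # enc_out: (T, D) encoder output; returns N-best list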