mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
61
funasr_local/modules/scorers/length_bonus.py
Normal file
61
funasr_local/modules/scorers/length_bonus.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Length bonus module."""
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from funasr_local.modules.scorers.scorer_interface import BatchScorerInterface
|
||||
|
||||
|
||||
class LengthBonus(BatchScorerInterface):
|
||||
"""Length bonus in beam search."""
|
||||
|
||||
def __init__(self, n_vocab: int):
|
||||
"""Initialize class.
|
||||
|
||||
Args:
|
||||
n_vocab (int): The number of tokens in vocabulary for beam search
|
||||
|
||||
"""
|
||||
self.n = n_vocab
|
||||
|
||||
def score(self, y, state, x):
|
||||
"""Score new token.
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): 1D torch.int64 prefix tokens.
|
||||
state: Scorer state for prefix tokens
|
||||
x (torch.Tensor): 2D encoder feature that generates ys.
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, Any]: Tuple of
|
||||
torch.float32 scores for next token (n_vocab)
|
||||
and None
|
||||
|
||||
"""
|
||||
return torch.tensor([1.0], device=x.device, dtype=x.dtype).expand(self.n), None
|
||||
|
||||
def batch_score(
|
||||
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, List[Any]]:
|
||||
"""Score new token batch.
|
||||
|
||||
Args:
|
||||
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
|
||||
states (List[Any]): Scorer states for prefix tokens.
|
||||
xs (torch.Tensor):
|
||||
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, List[Any]]: Tuple of
|
||||
batchfied scores for next token with shape of `(n_batch, n_vocab)`
|
||||
and next state list for ys.
|
||||
|
||||
"""
|
||||
return (
|
||||
torch.tensor([1.0], device=xs.device, dtype=xs.dtype).expand(
|
||||
ys.shape[0], self.n
|
||||
),
|
||||
None,
|
||||
)
|
||||
Reference in New Issue
Block a user