mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
This commit is contained in:
60
funasr_local/text/token_id_converter.py
Normal file
60
funasr_local/text/token_id_converter.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from typeguard import check_argument_types
|
||||
|
||||
|
||||
class TokenIDConverter:
|
||||
def __init__(
|
||||
self,
|
||||
token_list: Union[Path, str, Iterable[str]],
|
||||
unk_symbol: str = "<unk>",
|
||||
):
|
||||
assert check_argument_types()
|
||||
|
||||
if isinstance(token_list, (Path, str)):
|
||||
token_list = Path(token_list)
|
||||
self.token_list_repr = str(token_list)
|
||||
self.token_list: List[str] = []
|
||||
|
||||
with token_list.open("r", encoding="utf-8") as f:
|
||||
for idx, line in enumerate(f):
|
||||
line = line.rstrip()
|
||||
self.token_list.append(line)
|
||||
|
||||
else:
|
||||
self.token_list: List[str] = list(token_list)
|
||||
self.token_list_repr = ""
|
||||
for i, t in enumerate(self.token_list):
|
||||
if i == 3:
|
||||
break
|
||||
self.token_list_repr += f"{t}, "
|
||||
self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
|
||||
|
||||
self.token2id: Dict[str, int] = {}
|
||||
for i, t in enumerate(self.token_list):
|
||||
if t in self.token2id:
|
||||
raise RuntimeError(f'Symbol "{t}" is duplicated')
|
||||
self.token2id[t] = i
|
||||
|
||||
self.unk_symbol = unk_symbol
|
||||
if self.unk_symbol not in self.token2id:
|
||||
raise RuntimeError(
|
||||
f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
|
||||
)
|
||||
self.unk_id = self.token2id[self.unk_symbol]
|
||||
|
||||
def get_num_vocabulary_size(self) -> int:
|
||||
return len(self.token_list)
|
||||
|
||||
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
|
||||
if isinstance(integers, np.ndarray) and integers.ndim != 1:
|
||||
raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
|
||||
return [self.token_list[i] for i in integers]
|
||||
|
||||
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
|
||||
return [self.token2id.get(i, self.unk_id) for i in tokens]
|
||||
Reference in New Issue
Block a user