diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py index 6891b33..cc2f2da 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -609,3 +609,79 @@ class Qwen2LM(TransformerLM): # in stream mode, yield token one by one yield top_ids lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + +class CosyVoice3LM(Qwen2LM): + def __init__( + self, + llm_input_size: int, + llm_output_size: int, + speech_token_size: int, + llm: torch.nn.Module, + sampling: Callable, + length_normalized_loss: bool = True, + lsm_weight: float = 0.0, + mix_ratio: List[int] = [5, 15], + ): + torch.nn.Module.__init__(self) + self.llm_input_size = llm_input_size + self.llm_output_size = llm_output_size + self.speech_token_size = speech_token_size + # 2. build speech token language model related modules + self.sos = 0 + self.eos = 1 + self.task_id = 2 + self.fill_token = 3 + + self.llm = llm + self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False) + self.criterion_ce = LabelSmoothingLoss( + size=speech_token_size + 200, + padding_idx=IGNORE_ID, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + # 3. [Optional] build speech token related modules + self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size) + + # 4. sampling method + self.sampling = sampling + self.mix_ratio = mix_ratio + + @torch.inference_mode() + def inference( + self, + text: torch.Tensor, + text_len: torch.Tensor, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + uuid: str = '', + ) -> Generator[torch.Tensor, None, None]: + device = text.device + text = torch.concat([prompt_text, text], dim=1) + text_len += prompt_text_len + text = self.llm.model.model.embed_tokens(text) + + # 3. concat llm_input + sos_eos_emb = self.speech_embedding.weight[self.speech_token_size + self.sos].reshape(1, 1, -1) + task_id_emb = self.speech_embedding.weight[self.speech_token_size + self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) + lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1) + + # 4. cal min/max_length + min_len = int((text_len - prompt_text_len) * min_token_text_ratio) + max_len = int((text_len - prompt_text_len) * max_token_text_ratio) + + # 5. step by step decode + for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid): + yield token diff --git a/cosyvoice/tokenizer/tokenizer.py b/cosyvoice/tokenizer/tokenizer.py index 43fb39a..6ecf4ae 100644 --- a/cosyvoice/tokenizer/tokenizer.py +++ b/cosyvoice/tokenizer/tokenizer.py @@ -238,7 +238,7 @@ def get_tokenizer( ) -class QwenTokenizer(): +class CosyVoice2Tokenizer(): def __init__(self, token_path, skip_special_tokens=True): super().__init__() # NOTE: non-chat model, all these special tokens keep randomly initialized. @@ -271,9 +271,57 @@ class QwenTokenizer(): return text +class CosyVoice3Tokenizer(CosyVoice2Tokenizer): + def __init__(self, token_path, skip_special_tokens=True): + # NOTE: non-chat model, all these special tokens keep randomly initialized. + special_tokens = { + 'eos_token': '<|endoftext|>', + 'pad_token': '<|endoftext|>', + 'additional_special_tokens': [ + '<|im_start|>', '<|im_end|>', '<|endofprompt|>', + '[breath]', '', '', '[noise]', + '[laughter]', '[cough]', '[clucking]', '[accent]', + '[quick_breath]', + "", "", + "[hissing]", "[sigh]", "[vocalized-noise]", + "[lipsmack]", "[mn]", "<|endofsystem|>", + "[AA]", "[AA0]", "[AA1]", "[AA2]", "[AE]", "[AE0]", "[AE1]", "[AE2]", "[AH]", "[AH0]", "[AH1]", "[AH2]", + "[AO]", "[AO0]", "[AO1]", "[AO2]", "[AW]", "[AW0]", "[AW1]", "[AW2]", "[AY]", "[AY0]", "[AY1]", "[AY2]", + "[B]", "[CH]", "[D]", "[DH]", "[EH]", "[EH0]", "[EH1]", "[EH2]", "[ER]", "[ER0]", "[ER1]", "[ER2]", "[EY]", + "[EY0]", "[EY1]", "[EY2]", "[F]", "[G]", "[HH]", "[IH]", "[IH0]", "[IH1]", "[IH2]", "[IY]", "[IY0]", "[IY1]", + "[IY2]", "[JH]", "[K]", "[L]", "[M]", "[N]", "[NG]", "[OW]", "[OW0]", "[OW1]", "[OW2]", "[OY]", "[OY0]", + "[OY1]", "[OY2]", "[P]", "[R]", "[S]", "[SH]", "[T]", "[TH]", "[UH]", "[UH0]", "[UH1]", "[UH2]", "[UW]", + "[UW0]", "[UW1]", "[UW2]", "[V]", "[W]", "[Y]", "[Z]", "[ZH]", + "[a]", "[ai]", "[an]", "[ang]", "[ao]", "[b]", "[c]", "[ch]", "[d]", "[e]", "[ei]", "[en]", "[eng]", "[f]", + "[g]", "[h]", "[i]", "[ian]", "[in]", "[ing]", "[iu]", "[ià]", "[iàn]", "[iàng]", "[iào]", "[iá]", "[ián]", + "[iáng]", "[iáo]", "[iè]", "[ié]", "[iòng]", "[ióng]", "[iù]", "[iú]", "[iā]", "[iān]", "[iāng]", "[iāo]", + "[iē]", "[iě]", "[iōng]", "[iū]", "[iǎ]", "[iǎn]", "[iǎng]", "[iǎo]", "[iǒng]", "[iǔ]", "[j]", "[k]", "[l]", + "[m]", "[n]", "[o]", "[ong]", "[ou]", "[p]", "[q]", "[r]", "[s]", "[sh]", "[t]", "[u]", "[uang]", "[ue]", + "[un]", "[uo]", "[uà]", "[uài]", "[uàn]", "[uàng]", "[uá]", "[uái]", "[uán]", "[uáng]", "[uè]", "[ué]", "[uì]", + "[uí]", "[uò]", "[uó]", "[uā]", "[uāi]", "[uān]", "[uāng]", "[uē]", "[uě]", "[uī]", "[uō]", "[uǎ]", "[uǎi]", + "[uǎn]", "[uǎng]", "[uǐ]", "[uǒ]", "[vè]", "[w]", "[x]", "[y]", "[z]", "[zh]", "[à]", "[ài]", "[àn]", "[àng]", + "[ào]", "[á]", "[ái]", "[án]", "[áng]", "[áo]", "[è]", "[èi]", "[èn]", "[èng]", "[èr]", "[é]", "[éi]", "[én]", + "[éng]", "[ér]", "[ì]", "[ìn]", "[ìng]", "[í]", "[ín]", "[íng]", "[ò]", "[òng]", "[òu]", "[ó]", "[óng]", "[óu]", + "[ù]", "[ùn]", "[ú]", "[ún]", "[ā]", "[āi]", "[ān]", "[āng]", "[āo]", "[ē]", "[ēi]", "[ēn]", "[ēng]", "[ě]", + "[ěi]", "[ěn]", "[ěng]", "[ěr]", "[ī]", "[īn]", "[īng]", "[ō]", "[ōng]", "[ōu]", "[ū]", "[ūn]", "[ǎ]", "[ǎi]", + "[ǎn]", "[ǎng]", "[ǎo]", "[ǐ]", "[ǐn]", "[ǐng]", "[ǒ]", "[ǒng]", "[ǒu]", "[ǔ]", "[ǔn]", "[ǘ]", "[ǚ]", "[ǜ]" + ] + } + self.special_tokens = special_tokens + self.tokenizer = AutoTokenizer.from_pretrained(token_path) + self.tokenizer.add_special_tokens(special_tokens) + self.skip_special_tokens = skip_special_tokens + + @lru_cache(maxsize=None) def get_qwen_tokenizer( token_path: str, - skip_special_tokens: bool -) -> QwenTokenizer: - return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens) + skip_special_tokens: bool, + version: str = 'cosyvoice2' +): + if version == 'cosyvoice2': + return CosyVoice2Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens) + elif version == 'cosyvoice3': + return CosyVoice3Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens) + else: + raise ValueError