mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
add cosyvoice2
This commit is contained in:
@@ -2,6 +2,8 @@ import base64
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from whisper.tokenizer import Tokenizer
|
||||
|
||||
import tiktoken
|
||||
@@ -234,3 +236,37 @@ def get_tokenizer(
|
||||
return Tokenizer(
|
||||
encoding=encoding, num_languages=num_languages, language=language, task=task
|
||||
)
|
||||
|
||||
|
||||
class QwenTokenizer():
|
||||
def __init__(self, token_path, skip_special_tokens=True):
|
||||
special_tokens = {
|
||||
'eos_token': '<|endoftext|>',
|
||||
'pad_token': '<|endoftext|>',
|
||||
'additional_special_tokens': [
|
||||
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
||||
'[breath]', '<strong>', '</strong>', '[noise]',
|
||||
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
||||
'[quick_breath]',
|
||||
]
|
||||
}
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(token_path)
|
||||
self.tokenizer.add_special_tokens(special_tokens)
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
|
||||
def encode(self, text, **kwargs):
|
||||
tokens = self.tokenizer([text], return_tensors="pt")
|
||||
tokens = tokens["input_ids"][0].cpu().tolist()
|
||||
return tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
tokens = torch.tensor(tokens, dtype=torch.int64)
|
||||
text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
|
||||
return text
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_qwen_tokenizer(
|
||||
token_path: str,
|
||||
skip_special_tokens: bool
|
||||
) -> QwenTokenizer:
|
||||
return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
|
||||
Reference in New Issue
Block a user