add cosyvoice2

lyuxiang.lx committed 2024-12-11 16:14:19 +08:00
parent cde3cec6fa
commit 3e381002d7
9 changed files with 484 additions and 32 deletions


@@ -2,6 +2,8 @@ import base64
import os
from functools import lru_cache
from typing import Optional
import torch
from transformers import AutoTokenizer
from whisper.tokenizer import Tokenizer

import tiktoken
@@ -234,3 +236,37 @@ def get_tokenizer(
    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )
class QwenTokenizer():
    def __init__(self, token_path, skip_special_tokens=True):
        # Special tokens cover Qwen chat markup plus CosyVoice2's
        # paralinguistic tags such as [breath], [laughter] and [cough].
        special_tokens = {
            'eos_token': '<|endoftext|>',
            'pad_token': '<|endoftext|>',
            'additional_special_tokens': [
                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
                '[breath]', '<strong>', '</strong>', '[noise]',
                '[laughter]', '[cough]', '[clucking]', '[accent]',
                '[quick_breath]',
            ]
        }
        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
        self.tokenizer.add_special_tokens(special_tokens)
        self.skip_special_tokens = skip_special_tokens

    def encode(self, text, **kwargs):
        # Tokenize a single string and return its token ids as a plain list.
        tokens = self.tokenizer([text], return_tensors="pt")
        tokens = tokens["input_ids"][0].cpu().tolist()
        return tokens

    def decode(self, tokens):
        # Accept any integer id sequence and decode it back to text,
        # optionally stripping the special tokens registered above.
        tokens = torch.tensor(tokens, dtype=torch.int64)
        text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
        return text


@lru_cache(maxsize=None)
def get_qwen_tokenizer(
    token_path: str,
    skip_special_tokens: bool
) -> QwenTokenizer:
    # Cache one tokenizer per (token_path, skip_special_tokens) pair so
    # repeated lookups do not reload the tokenizer from disk.
    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
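
For reference, a minimal usage sketch of the new API, assuming this diff lands in cosyvoice/tokenizer/tokenizer.py as in the CosyVoice repo; "path/to/qwen_tokenizer" is a placeholder for any local directory holding a Hugging Face Qwen tokenizer:

# Usage sketch, not part of the diff. The tokenizer path is a placeholder.
from cosyvoice.tokenizer.tokenizer import get_qwen_tokenizer

tokenizer = get_qwen_tokenizer(token_path="path/to/qwen_tokenizer",
                               skip_special_tokens=True)

ids = tokenizer.encode("hello [laughter] world")  # each registered tag maps to a single id
text = tokenizer.decode(ids)                      # tags dropped when skip_special_tokens=True

# lru_cache keys on the exact call arguments, so identical calls reuse the
# same QwenTokenizer instance instead of reloading the tokenizer from disk.
assert get_qwen_tokenizer(token_path="path/to/qwen_tokenizer",
                          skip_special_tokens=True) is tokenizer

One caveat worth noting: add_special_tokens grows the vocabulary, so any language model consuming these ids needs its embedding table sized to match (on the Hugging Face side, model.resize_token_embeddings(len(tokenizer.tokenizer))).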