add func export_codec_vllm
@@ -126,7 +126,7 @@ class CosyVoice:
 
 class CosyVoice2(CosyVoice):
 
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_vllm=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
@@ -149,6 +149,8 @@ class CosyVoice2(CosyVoice):
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
+        if use_vllm:
+            self.model.export_codec_vllm(''.join([model_dir, '/codec_vllm_model']))
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
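The two hunks above are from the CosyVoice2 wrapper class (cosyvoice/cli/cosyvoice.py in the repo layout): a new use_vllm flag triggers a one-time export of the codec language model right after the llm/flow/hift checkpoints are loaded. A minimal usage sketch, assuming the usual pretrained model directory (the path is illustrative, not part of the commit):

    from cosyvoice.cli.cosyvoice import CosyVoice2

    # use_vllm=True makes the constructor call export_codec_vllm(), writing a
    # HuggingFace-format checkpoint to <model_dir>/codec_vllm_model; later runs
    # skip the export because that folder already exists.
    cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', use_vllm=True)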
@@ -17,6 +17,7 @@ import torch
 import numpy as np
 import threading
 import time
+from torch import nn
 from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
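The single added import, from torch import nn, backs the nn.Linear layers that the new export function below constructs.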
@@ -317,6 +318,38 @@ class CosyVoice2Model(CosyVoiceModel):
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
 
+    def export_codec_vllm(self, model_path):
+        if os.path.exists(model_path):
+            return
+        pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
+        vocab_size = self.llm.speech_embedding.num_embeddings
+        feature_size = self.llm.speech_embedding.embedding_dim
+        pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
+
+        dtype = torch.bfloat16
+        new_lm_head = nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
+        with torch.no_grad():
+            new_lm_head.weight[:vocab_size] = self.llm.llm_decoder.weight
+            new_lm_head.bias[:vocab_size] = self.llm.llm_decoder.bias
+            new_lm_head.weight[vocab_size:] = 0
+            new_lm_head.bias[vocab_size:] = 0
+        self.llm.llm.model.lm_head = new_lm_head
+        new_codec_embed = nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
+        with torch.no_grad():
+            new_codec_embed.weight[:vocab_size] = self.llm.speech_embedding.weight
+            new_codec_embed.weight[vocab_size:] = 0
+        self.llm.llm.model.set_input_embeddings(new_codec_embed)
+        self.llm.llm.model.to(self.device)
+        self.llm.llm.model.to(dtype)
+        tmp_vocab_size = self.llm.llm.model.config.vocab_size
+        tmp_tie_embedding = self.llm.llm.model.config.tie_word_embeddings
+        self.llm.llm.model.config.vocab_size = pad_vocab_size
+        self.llm.llm.model.config.tie_word_embeddings = False
+        self.llm.llm.model.config.use_bias = True
+        self.llm.llm.model.save_pretrained(model_path)
+        self.llm.llm.model.config.vocab_size = tmp_vocab_size
+        self.llm.llm.model.config.tie_word_embeddings = tmp_tie_embedding
+
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
         tts_mel, _ = self.flow.inference(token=token.to(self.device),
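This last hunk (the CosyVoice2Model class in cosyvoice/cli/model.py) is the export itself: it pads the speech-codec vocabulary up to the next multiple of 64, vLLM's default vocabulary padding, builds a widened lm_head and input embedding whose extra rows are zeroed, and saves the model with save_pretrained. Note that the config's vocab_size and tie_word_embeddings are patched only for the duration of the save and then restored, so the in-memory model keeps running the regular pipeline. A sketch of the padding arithmetic and of loading the result, assuming an illustrative vocab size of 6564 and that vLLM supports the exported base architecture:

    # Padding arithmetic: round vocab_size up to the next multiple of 64.
    pad_to = 64
    vocab_size = 6564  # illustrative value for speech_embedding.num_embeddings
    pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
    assert pad_vocab_size == 6592  # 103 * 64

    # The exported folder is a plain HuggingFace checkpoint, so it can be
    # handed to vLLM directly (model path is illustrative):
    from vllm import LLM
    llm = LLM(model='pretrained_models/CosyVoice2-0.5B/codec_vllm_model')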