add func export_codec_vllm

雾聪
2025-02-25 15:01:29 +08:00
parent fd45708e4b
commit c37c00ff94
2 changed files with 36 additions and 1 deletion


@@ -126,7 +126,7 @@ class CosyVoice:
 class CosyVoice2(CosyVoice):
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_vllm=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
@@ -149,6 +149,8 @@ class CosyVoice2(CosyVoice):
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
+        if use_vllm:
+            self.model.export_codec_vllm(''.join([model_dir, '/codec_vllm_model']))
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
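
This hunk threads a new `use_vllm` flag through `CosyVoice2.__init__`: when it is set, the freshly loaded model exports a codec LM checkpoint intended for vLLM into `<model_dir>/codec_vllm_model`, right after the weights are loaded (the export is skipped if that directory already exists, see `export_codec_vllm` below). A minimal usage sketch, assuming the repository's usual `cosyvoice.cli.cosyvoice` import path; the model directory is a placeholder, not taken from this commit:

from cosyvoice.cli.cosyvoice import CosyVoice2

# Placeholder path; point this at your local CosyVoice2 checkpoint directory.
model_dir = 'pretrained_models/CosyVoice2-0.5B'
cosyvoice = CosyVoice2(model_dir, use_vllm=True)
# With use_vllm=True the constructor calls model.export_codec_vllm(...),
# writing the padded checkpoint to '<model_dir>/codec_vllm_model' on first use.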


@@ -17,6 +17,7 @@ import torch
 import numpy as np
 import threading
 import time
+from torch import nn
 from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
@@ -317,6 +318,38 @@ class CosyVoice2Model(CosyVoiceModel):
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
+
+    def export_codec_vllm(self, model_path):
+        if os.path.exists(model_path):
+            return
+        pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
+        vocab_size = self.llm.speech_embedding.num_embeddings
+        feature_size = self.llm.speech_embedding.embedding_dim
+        pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
+        dtype = torch.bfloat16
+        new_lm_head = nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
+        with torch.no_grad():
+            new_lm_head.weight[:vocab_size] = self.llm.llm_decoder.weight
+            new_lm_head.bias[:vocab_size] = self.llm.llm_decoder.bias
+            new_lm_head.weight[vocab_size:] = 0
+            new_lm_head.bias[vocab_size:] = 0
+        self.llm.llm.model.lm_head = new_lm_head
+        new_codec_embed = nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
+        with torch.no_grad():
+            new_codec_embed.weight[:vocab_size] = self.llm.speech_embedding.weight
+            new_codec_embed.weight[vocab_size:] = 0
+        self.llm.llm.model.set_input_embeddings(new_codec_embed)
+        self.llm.llm.model.to(self.device)
+        self.llm.llm.model.to(dtype)
+        tmp_vocab_size = self.llm.llm.model.config.vocab_size
+        tmp_tie_embedding = self.llm.llm.model.config.tie_word_embeddings
+        self.llm.llm.model.config.vocab_size = pad_vocab_size
+        self.llm.llm.model.config.tie_word_embeddings = False
+        self.llm.llm.model.config.use_bias = True
+        self.llm.llm.model.save_pretrained(model_path)
+        self.llm.llm.model.config.vocab_size = tmp_vocab_size
+        self.llm.llm.model.config.tie_word_embeddings = tmp_tie_embedding
+
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
         tts_mel, _ = self.flow.inference(token=token.to(self.device),
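
What the new method does: it rounds the speech-token vocabulary (`speech_embedding.num_embeddings`) up to a multiple of 64, builds a replacement `lm_head` and input-embedding layer whose first `vocab_size` rows are copied from `llm_decoder` and `speech_embedding` and whose padding rows are zeroed, swaps them into the underlying Hugging Face causal LM, temporarily overrides `vocab_size` and `tie_word_embeddings` and sets `use_bias = True` in the config, saves the checkpoint with `save_pretrained(model_path)`, and finally restores the original `vocab_size` and `tie_word_embeddings` values. A minimal sanity check for the exported checkpoint, assuming the wrapped LM is a standard Hugging Face causal LM with an `lm_head` attribute (as `save_pretrained()` above implies); the path is a placeholder:

import torch
from transformers import AutoModelForCausalLM

# Placeholder path: wherever CosyVoice2(..., use_vllm=True) wrote the export.
export_dir = 'pretrained_models/CosyVoice2-0.5B/codec_vllm_model'
model = AutoModelForCausalLM.from_pretrained(export_dir, torch_dtype=torch.bfloat16)

pad_to = 64  # matches DEFAULT_VOCAB_PADDING_SIZE in export_codec_vllm
assert model.config.vocab_size % pad_to == 0                               # vocab padded to a multiple of 64
assert model.lm_head.out_features == model.config.vocab_size               # padded output head
assert model.get_input_embeddings().weight.shape[0] == model.config.vocab_size  # padded input embeddings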