mirror of https://github.com/FunAudioLLM/CosyVoice.git
add func export_codec_vllm
@@ -126,7 +126,7 @@ class CosyVoice:

 class CosyVoice2(CosyVoice):

-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_vllm=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
@@ -149,6 +149,8 @@ class CosyVoice2(CosyVoice):
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
+        if use_vllm:
+            self.model.export_codec_vllm(''.join([model_dir, '/codec_vllm_model']))
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
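For reference, a minimal sketch of how the new flag would be exercised, assuming the upstream cosyvoice.cli.cosyvoice import path and an illustrative checkpoint directory (neither appears in this commit):

    from cosyvoice.cli.cosyvoice import CosyVoice2

    # With use_vllm=True the constructor calls export_codec_vllm(), which
    # writes a vLLM-loadable copy of the codec LM to
    # <model_dir>/codec_vllm_model; later runs return early once that
    # directory exists.
    cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', use_vllm=True)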
@@ -17,6 +17,7 @@ import torch
 import numpy as np
 import threading
 import time
+from torch import nn
 from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
@@ -317,6 +318,38 @@ class CosyVoice2Model(CosyVoiceModel):
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

+    def export_codec_vllm(self, model_path):
+        if os.path.exists(model_path):
+            return
+        pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
+        vocab_size = self.llm.speech_embedding.num_embeddings
+        feature_size = self.llm.speech_embedding.embedding_dim
+        pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
+
+        dtype = torch.bfloat16
+        new_lm_head = nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
+        with torch.no_grad():
+            new_lm_head.weight[:vocab_size] = self.llm.llm_decoder.weight
+            new_lm_head.bias[:vocab_size] = self.llm.llm_decoder.bias
+            new_lm_head.weight[vocab_size:] = 0
+            new_lm_head.bias[vocab_size:] = 0
+        self.llm.llm.model.lm_head = new_lm_head
+        new_codec_embed = nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
+        with torch.no_grad():
+            new_codec_embed.weight[:vocab_size] = self.llm.speech_embedding.weight
+            new_codec_embed.weight[vocab_size:] = 0
+        self.llm.llm.model.set_input_embeddings(new_codec_embed)
+        self.llm.llm.model.to(self.device)
+        self.llm.llm.model.to(dtype)
+        tmp_vocab_size = self.llm.llm.model.config.vocab_size
+        tmp_tie_embedding = self.llm.llm.model.config.tie_word_embeddings
+        self.llm.llm.model.config.vocab_size = pad_vocab_size
+        self.llm.llm.model.config.tie_word_embeddings = False
+        self.llm.llm.model.config.use_bias = True
+        self.llm.llm.model.save_pretrained(model_path)
+        self.llm.llm.model.config.vocab_size = tmp_vocab_size
+        self.llm.llm.model.config.tie_word_embeddings = tmp_tie_embedding
+
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
         tts_mel, _ = self.flow.inference(token=token.to(self.device),
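The new CosyVoice2Model.export_codec_vllm rounds the speech-codec vocabulary up to a multiple of 64 (DEFAULT_VOCAB_PADDING_SIZE, vLLM's default vocab padding), zero-fills the extra rows of the lm_head and input embeddings, and temporarily patches config.vocab_size and tie_word_embeddings around save_pretrained() so the checkpoint on disk has vLLM-friendly shapes while the in-memory config is restored afterwards. A toy sketch of the same padding pattern, with made-up sizes (10 rows padded to a multiple of 4; the real export uses speech_embedding's sizes and 64):

    import torch
    from torch import nn

    vocab_size, feature_size, pad_to = 10, 8, 4
    pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to  # ceil(10/4)*4 = 12

    # nn.Linear(in, out) stores its weight as (out, in), the same
    # (vocab, hidden) layout an embedding table uses, which is why the
    # commit reuses Linear for both the padded lm_head and the padded
    # input embeddings.
    emb = nn.Embedding(vocab_size, feature_size)
    new_embed = nn.Linear(feature_size, pad_vocab_size)
    with torch.no_grad():
        new_embed.weight[:vocab_size] = emb.weight  # copy the real rows
        new_embed.weight[vocab_size:] = 0           # zero the padding rows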