From c37c00ff94dbaf89016d13c01780aa56ba7afead Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=BE=E8=81=AA?=
Date: Tue, 25 Feb 2025 15:01:29 +0800
Subject: [PATCH] add func export_codec_vllm

---
 cosyvoice/cli/cosyvoice.py |  4 +++-
 cosyvoice/cli/model.py     | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index e2d62e2..b8fe756 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -126,7 +126,7 @@ class CosyVoice:
 
 class CosyVoice2(CosyVoice):
 
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_vllm=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
@@ -149,6 +149,8 @@ class CosyVoice2(CosyVoice):
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
+        if use_vllm:
+            self.model.export_codec_vllm(''.join([model_dir, '/codec_vllm_model']))
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 9ebf8cb..115d7e1 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -17,6 +17,7 @@ import torch
 import numpy as np
 import threading
 import time
+from torch import nn
 from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
@@ -317,6 +318,38 @@ class CosyVoice2Model(CosyVoiceModel):
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
+
+    def export_codec_vllm(self, model_path):
+        if os.path.exists(model_path):
+            return
+        pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
+        vocab_size = self.llm.speech_embedding.num_embeddings
+        feature_size = self.llm.speech_embedding.embedding_dim
+        pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
+
+        dtype = torch.bfloat16
+        new_lm_head = nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
+        with torch.no_grad():
+            new_lm_head.weight[:vocab_size] = self.llm.llm_decoder.weight
+            new_lm_head.bias[:vocab_size] = self.llm.llm_decoder.bias
+            new_lm_head.weight[vocab_size:] = 0
+            new_lm_head.bias[vocab_size:] = 0
+        self.llm.llm.model.lm_head = new_lm_head
+        new_codec_embed = nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
+        with torch.no_grad():
+            new_codec_embed.weight[:vocab_size] = self.llm.speech_embedding.weight
+            new_codec_embed.weight[vocab_size:] = 0
+        self.llm.llm.model.set_input_embeddings(new_codec_embed)
+        self.llm.llm.model.to(self.device)
+        self.llm.llm.model.to(dtype)
+        tmp_vocab_size = self.llm.llm.model.config.vocab_size
+        tmp_tie_embedding = self.llm.llm.model.config.tie_word_embeddings
+        self.llm.llm.model.config.vocab_size = pad_vocab_size
+        self.llm.llm.model.config.tie_word_embeddings = tmp_tie_embedding if False else False
+        self.llm.llm.model.config.use_bias = True
+        self.llm.llm.model.save_pretrained(model_path)
+        self.llm.llm.model.config.vocab_size = tmp_vocab_size
+        self.llm.llm.model.config.tie_word_embeddings = tmp_tie_embedding
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
         tts_mel, _ = self.flow.inference(token=token.to(self.device),
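
Usage note (not part of the patch): a minimal sketch of how the new flag might be exercised. Only the use_vllm argument and the codec_vllm_model output directory come from the patch; the model directory 'pretrained_models/CosyVoice2-0.5B' and the vLLM loading step are illustrative assumptions.

    from cosyvoice.cli.cosyvoice import CosyVoice2

    # First init with use_vllm=True writes the padded, bf16 Hugging Face
    # checkpoint to <model_dir>/codec_vllm_model (skipped if it already exists).
    cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', use_vllm=True)

    # The exported directory could then be handed to a vLLM engine, e.g.:
    # from vllm import LLM
    # llm = LLM(model='pretrained_models/CosyVoice2-0.5B/codec_vllm_model')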