add vllm inference

lyuxiang.lx
2025-05-30 07:22:35 +00:00
parent 9f55c5af8f
commit 6dd68b9d5e
6 changed files with 105 additions and 64 deletions

View File

@@ -140,7 +140,7 @@ class CosyVoice:
 class CosyVoice2(CosyVoice):

-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
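Not part of the diff: a minimal usage sketch of the new flag. The checkpoint path, prompt wav and texts below are illustrative, and it assumes load_vllm=True makes the wrapper call the new CosyVoice2Model.load_vllm shown in the next file.

# Hedged usage sketch: enable the new vLLM LLM backend when constructing CosyVoice2.
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B',
                       load_jit=False, load_trt=False, load_vllm=True, fp16=False)
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
for i, out in enumerate(cosyvoice.inference_zero_shot('Hello, this is a vLLM-backed synthesis test.',
                                                      'This is the prompt transcript.',
                                                      prompt_speech_16k, stream=False)):
    torchaudio.save('vllm_zero_shot_{}.wav'.format(i), out['tts_speech'], cosyvoice.sample_rate)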

View File

@@ -59,9 +59,6 @@ class CosyVoiceModel:
         self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
-        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
-        for _ in range(trt_concurrent):
-            self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
@@ -69,7 +66,6 @@ class CosyVoiceModel:
         self.mel_overlap_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
-        self.trt_context_dict = {}

     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
@@ -98,7 +94,7 @@ class CosyVoiceModel:
         with open(flow_decoder_estimator_model, 'rb') as f:
             estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
         assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
-        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent, device=self.device)

     def get_trt_kwargs(self):
         min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
@@ -125,7 +121,8 @@ class CosyVoiceModel:
                                             prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                             prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                            embedding=llm_embedding.to(self.device)):
+                                            embedding=llm_embedding.to(self.device),
+                                            uuid=uuid):
                 self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
@@ -180,13 +177,11 @@ class CosyVoiceModel:
             prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
-        this_trt_context = self.trt_context_pool.get()
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
             self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
             self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
-            self.trt_context_dict[this_uuid] = this_trt_context
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:
@@ -240,8 +235,6 @@ class CosyVoiceModel:
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
-            self.trt_context_pool.put(self.trt_context_dict[this_uuid])
-            self.trt_context_dict.pop(this_uuid)
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.current_stream().synchronize()
@@ -273,22 +266,28 @@ class CosyVoice2Model(CosyVoiceModel):
         self.speech_window = np.hamming(2 * self.source_cache_len)
         # rtf and decoding related
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
-        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
-        for _ in range(trt_concurrent):
-            self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
-        self.trt_context_dict = {}

     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

+    def load_vllm(self, model_dir):
+        export_cosyvoice2_vllm(self.llm, model_dir, self.device)
+        from vllm import EngineArgs, LLMEngine
+        engine_args = EngineArgs(model=model_dir,
+                                 skip_tokenizer_init=True,
+                                 enable_prompt_embeds=True,
+                                 gpu_memory_utilization=0.2)
+        self.llm.vllm = LLMEngine.from_engine_args(engine_args)
+        del self.llm.llm.model.model.layers
+
     def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
-        with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
+        with torch.cuda.amp.autocast(self.fp16):
             tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                              token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                              prompt_token=prompt_token.to(self.device),
@@ -330,11 +329,9 @@ class CosyVoice2Model(CosyVoiceModel):
             prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
-        this_trt_context = self.trt_context_pool.get()
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
-            self.trt_context_dict[this_uuid] = this_trt_context
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:
@@ -388,8 +385,6 @@ class CosyVoice2Model(CosyVoiceModel):
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
-            self.trt_context_pool.put(self.trt_context_dict[this_uuid])
-            self.trt_context_dict.pop(this_uuid)
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.current_stream().synchronize()

View File

@@ -16,6 +16,7 @@ import threading
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
+from cosyvoice.utils.common import set_all_random_seed


 class ConditionalCFM(BASECFM):
@@ -32,7 +33,6 @@ class ConditionalCFM(BASECFM):
         in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
         # Just change the architecture of the estimator here
         self.estimator = estimator
-        self.lock = threading.Lock()

     @torch.inference_mode()
     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
@@ -127,26 +127,27 @@ class ConditionalCFM(BASECFM):
         if isinstance(self.estimator, torch.nn.Module):
             return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
         else:
-            estimator, trt_engine = self.estimator.acquire_estimator()
-            estimator.set_input_shape('x', (2, 80, x.size(2)))
-            estimator.set_input_shape('mask', (2, 1, x.size(2)))
-            estimator.set_input_shape('mu', (2, 80, x.size(2)))
-            estimator.set_input_shape('t', (2,))
-            estimator.set_input_shape('spks', (2, 80))
-            estimator.set_input_shape('cond', (2, 80, x.size(2)))
-            data_ptrs = [x.contiguous().data_ptr(),
-                         mask.contiguous().data_ptr(),
-                         mu.contiguous().data_ptr(),
-                         t.contiguous().data_ptr(),
-                         spks.contiguous().data_ptr(),
-                         cond.contiguous().data_ptr(),
-                         x.data_ptr()]
-            for i, j in enumerate(data_ptrs):
-                estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
-            # run trt engine
-            assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
-            torch.cuda.current_stream().synchronize()
-            self.estimator.release_estimator(estimator)
+            [estimator, stream], trt_engine = self.estimator.acquire_estimator()
+            with stream:
+                estimator.set_input_shape('x', (2, 80, x.size(2)))
+                estimator.set_input_shape('mask', (2, 1, x.size(2)))
+                estimator.set_input_shape('mu', (2, 80, x.size(2)))
+                estimator.set_input_shape('t', (2,))
+                estimator.set_input_shape('spks', (2, 80))
+                estimator.set_input_shape('cond', (2, 80, x.size(2)))
+                data_ptrs = [x.contiguous().data_ptr(),
+                             mask.contiguous().data_ptr(),
+                             mu.contiguous().data_ptr(),
+                             t.contiguous().data_ptr(),
+                             spks.contiguous().data_ptr(),
+                             cond.contiguous().data_ptr(),
+                             x.data_ptr()]
+                for i, j in enumerate(data_ptrs):
+                    estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
+                # run trt engine
+                assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+                torch.cuda.current_stream().synchronize()
+            self.estimator.release_estimator(estimator, stream)
         return x

     def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
@@ -194,6 +195,7 @@ class ConditionalCFM(BASECFM):
 class CausalConditionalCFM(ConditionalCFM):
     def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
         super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
+        set_all_random_seed(0)
         self.rand_noise = torch.randn([1, 80, 50 * 300])

     @torch.inference_mode()
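The added set_all_random_seed(0) pins the fixed CFM noise buffer rand_noise, so separate processes (for example a vLLM-backed server and a plain PyTorch run) start from identical noise. A small, runnable illustration of the idea using torch.manual_seed directly; the helper is assumed to seed python/numpy/torch together.

# Minimal illustration: seeding right before the randn call makes the noise buffer reproducible.
import torch

torch.manual_seed(0)
a = torch.randn([1, 80, 50 * 300])
torch.manual_seed(0)
b = torch.randn([1, 80, 50 * 300])
assert torch.equal(a, b)  # identical buffers across runs/processes when seeded the same way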

View File

@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import queue
 import random
+import time
+import threading
 from typing import Dict, Optional, Callable, List, Generator
 import torch
 from torch import nn
@@ -170,6 +173,7 @@ class TransformerLM(torch.nn.Module):
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
+            uuid: str = '',
     ) -> Generator[torch.Tensor, None, None]:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
@@ -270,7 +274,6 @@ class Qwen2LM(TransformerLM):
         self.llm_input_size = llm_input_size
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
-
         # 2. build speech token language model related modules
         self.sos_eos = 0
         self.task_id = 1
@@ -292,6 +295,11 @@ class Qwen2LM(TransformerLM):
         # 4. sampling method
         self.sampling = sampling
         self.mix_ratio = mix_ratio
+        # 5. vllm related
+        self.stop_token_ids = [speech_token_size + i for i in range(3)]
+        self.vllm_output_queue = {}
+        self.lock = threading.Lock()

     def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
         lm_target, lm_input = [], []
@@ -382,6 +390,7 @@ class Qwen2LM(TransformerLM):
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
+            uuid: str = '',
     ) -> Generator[torch.Tensor, None, None]:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
@@ -402,22 +411,55 @@ class Qwen2LM(TransformerLM):
         max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

         # 5. step by step decode
-        out_tokens = []
-        cache = None
-        for i in range(max_len):
-            y_pred, cache = self.llm.forward_one_step(lm_input,
-                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
-                                                      cache=cache)
-            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-            if top_ids == self.speech_token_size:
-                break
-            if top_ids > self.speech_token_size:
-                continue
-            # in stream mode, yield token one by one
-            yield top_ids
-            out_tokens.append(top_ids)
-            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
+            yield token
+
+    @torch.inference_mode()
+    def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
+        if hasattr(self, 'vllm'):
+            from vllm import SamplingParams, RequestOutput
+            sampling_params = SamplingParams(top_k=sampling,
+                                             stop_token_ids=self.stop_token_ids,
+                                             min_tokens=min_len,
+                                             max_tokens=max_len)
+            with self.lock:
+                self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
+                self.vllm_output_queue[uuid] = queue.Queue()
+            out_tokens = []
+            while True:
+                with self.lock:
+                    if self.vllm_output_queue[uuid].empty() is True:
+                        request_outputs: List[RequestOutput] = self.vllm.step()
+                        for request_output in request_outputs:
+                            top_ids = list(request_output.outputs[0].token_ids)[-1]
+                            self.vllm_output_queue[request_output.request_id].put(top_ids)
+                if self.vllm_output_queue[uuid].empty() is False:
+                    top_ids = self.vllm_output_queue[uuid].get()
+                    if top_ids in self.stop_token_ids:
+                        break
+                    # in stream mode, yield token one by one
+                    yield top_ids
+                    out_tokens.append(top_ids)
+                time.sleep(0.001)
+            with self.lock:
+                self.vllm_output_queue.pop(uuid)
+        else:
+            out_tokens = []
+            cache = None
+            for i in range(max_len):
+                y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                          masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
+                                                          cache=cache)
+                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
+                if top_ids == self.speech_token_size:
+                    break
+                if top_ids > self.speech_token_size:
+                    continue
+                # in stream mode, yield token one by one
+                yield top_ids
+                out_tokens.append(top_ids)
+                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

     @torch.inference_mode()
     def inference_bistream(

View File

@@ -169,17 +169,18 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
 class TrtContextWrapper:
-    def __init__(self, trt_engine, trt_concurrent=1):
-        self.trt_context_pool = queue.Queue()
+    def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
+        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
         self.trt_engine = trt_engine
         for _ in range(trt_concurrent):
             trt_context = trt_engine.create_execution_context()
+            trt_stream = torch.cuda.stream(torch.cuda.Stream(device))
             assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
-            self.trt_context_pool.put(trt_context)
+            self.trt_context_pool.put([trt_context, trt_stream])
         assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'

     def acquire_estimator(self):
         return self.trt_context_pool.get(), self.trt_engine

-    def release_estimator(self, context):
-        self.trt_context_pool.put(context)
+    def release_estimator(self, context, stream):
+        self.trt_context_pool.put([context, stream])

View File

@@ -58,7 +58,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
     network = builder.create_network(network_flags)
     parser = trt.OnnxParser(network, logger)
     config = builder.create_builder_config()
-    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 31)  # 1GB
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB
     if fp16:
         config.set_flag(trt.BuilderFlag.FP16)
     profile = builder.create_optimization_profile()
@@ -122,6 +122,7 @@ def export_cosyvoice2_vllm(model, model_path, device):
     model.llm.model.config.tie_word_embeddings = False
     model.llm.model.config.use_bias = True
     model.llm.model.save_pretrained(model_path)
+    os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
     model.llm.model.config.vocab_size = tmp_vocab_size
     model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
     model.llm.model.set_input_embeddings(embed_tokens)
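The sed call above rewrites the exported config.json so the checkpoint advertises CosyVoice2ForCausalLM instead of Qwen2ForCausalLM. A hedged, in-process alternative that edits the architectures field with json instead of shelling out; the helper name is illustrative, not part of the commit.

# Illustrative replacement for the os.system('sed ...') line: patch the exported config.json in-process.
import json
import os

def patch_vllm_architecture(model_path):
    config_file = os.path.join(os.path.abspath(model_path), 'config.json')
    with open(config_file) as f:
        config = json.load(f)
    # rename the architecture so vLLM picks up the custom CosyVoice2ForCausalLM class
    config['architectures'] = ['CosyVoice2ForCausalLM' if name == 'Qwen2ForCausalLM' else name
                               for name in config.get('architectures', [])]
    with open(config_file, 'w') as f:
        json.dump(config, f, indent=2)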