Mirror of https://github.com/OpenBMB/MiniCPM-V.git (synced 2026-02-04 17:59:18 +08:00)
Update finetune.py
@@ -66,42 +66,6 @@ class LoraArguments:
    lora_layer_replication: Optional[List[Tuple[int, int]]] = None
    lora_layers_to_transform: Optional[List[int]] = None
    lora_layers_pattern: Optional[str] = None

def maybe_zero_3(param):
    if hasattr(param, "ds_id"):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias:
            if bias_name in lora_bias_names:
                to_return[bias_name] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
    return to_return


local_rank = None
def rank0_print(*args):
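For reference, the two helpers removed above were typically used together when exporting LoRA-only weights under DeepSpeed ZeRO-3: maybe_zero_3 gathers a possibly partitioned parameter and copies it to CPU, and get_peft_state_maybe_zero_3 filters the named parameters down to the lora_ (and, depending on bias, bias) entries. A minimal hedged sketch of that usage follows; model, adapter_dir, and the rank check are illustrative placeholders, not part of this commit.

# Hedged sketch (illustrative, not from the diff): dump LoRA-only weights
# with the helpers shown above. `model` and `adapter_dir` are placeholders.
import os
import torch

lora_state = get_peft_state_maybe_zero_3(model.named_parameters(), bias="none")
# maybe_zero_3 has already detached each tensor and moved it to CPU, so a
# plain torch.save on the main process is sufficient.
if int(os.environ.get("LOCAL_RANK", "0")) == 0:
    torch.save(lora_state, os.path.join(adapter_dir, "adapter_model.bin"))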
@@ -111,18 +75,8 @@ def rank0_print(*args):

def safe_save_model_for_hf_trainer(trainer, output_dir: str, bias="none"):
    """Collects the state dict and dump to disk."""
    # check if zero3 mode enabled
    if deepspeed.is_deepspeed_zero3_enabled():
        state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    else:
        if trainer.args.use_lora:
            state_dict = get_peft_state_maybe_zero_3(
                trainer.model.named_parameters(), bias
            )
        else:
            state_dict = trainer.model.state_dict()
    if trainer.args.should_save and trainer.args.local_rank == 0:
        trainer._save(output_dir, state_dict=state_dict)
    trainer.save_model(output_dir,)


def make_supervised_data_module(
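Taken together, the hunk counts (-18/+8) suggest the commit drops the manual ZeRO-3 / LoRA state-dict consolidation and leaves saving to trainer.save_model(output_dir,), which in recent transformers versions is expected to handle DeepSpeed consolidation and PEFT adapter saving itself. A minimal hedged sketch of the resulting call site; trainer and training_args are the usual Trainer / TrainingArguments objects assumed from the rest of finetune.py, not shown in this diff.

# Hedged sketch (assumed call site, not shown in the diff): after training,
# the simplified helper just forwards to Trainer.save_model.
trainer.train()
safe_save_model_for_hf_trainer(trainer, training_args.output_dir)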