mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-05 10:19:18 +08:00
update q_lora code and memory cost with zero3 and offloading (#200)
This commit is contained in:
@@ -19,7 +19,7 @@ from transformers import AutoModel, AutoTokenizer
|
||||
from dataset import SupervisedDataset, data_collator
|
||||
from trainer import CPMTrainer
|
||||
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
@@ -220,7 +220,13 @@ def train():
|
||||
local_rank = training_args.local_rank
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
ddp = world_size != 1
|
||||
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
|
||||
device_map = None
|
||||
if lora_args.q_lora:
|
||||
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
|
||||
if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
|
||||
logging.warning(
|
||||
"FSDP or ZeRO3 are not incompatible with QLoRA."
|
||||
)
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
@@ -258,11 +264,15 @@ def train():
|
||||
def get_input_embeddings(self):
|
||||
return self.llm.get_input_embeddings()
|
||||
model.get_input_embeddings = MethodType(get_input_embeddings, model)
|
||||
if lora_args.q_lora:
|
||||
model = prepare_model_for_kbit_training(
|
||||
model, use_gradient_checkpointing=training_args.gradient_checkpointing
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
|
||||
model.base_model.resampler.requires_grad_(True)
|
||||
if training_args.tune_vision:
|
||||
model.base_model.vpm.requires_grad_(True)
|
||||
model.base_model.resampler.requires_grad_(True)
|
||||
model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
|
||||
if training_args.gradient_checkpointing:
|
||||
model.enable_input_require_grads()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user