mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-04 17:59:18 +08:00
update q_lora code and memory cost with zero3 and offloading (#200)
This commit is contained in:
5
.vscode/settings.json
vendored
Normal file
5
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"githubPullRequests.ignoredPullRequestBranches": [
|
||||
"main"
|
||||
]
|
||||
}
|
||||
@@ -19,7 +19,7 @@ from transformers import AutoModel, AutoTokenizer
|
||||
from dataset import SupervisedDataset, data_collator
|
||||
from trainer import CPMTrainer
|
||||
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
@@ -220,7 +220,13 @@ def train():
|
||||
local_rank = training_args.local_rank
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
ddp = world_size != 1
|
||||
device_map = None
|
||||
if lora_args.q_lora:
|
||||
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
|
||||
if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
|
||||
logging.warning(
|
||||
"FSDP or ZeRO3 are not incompatible with QLoRA."
|
||||
)
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
@@ -258,11 +264,15 @@ def train():
|
||||
def get_input_embeddings(self):
|
||||
return self.llm.get_input_embeddings()
|
||||
model.get_input_embeddings = MethodType(get_input_embeddings, model)
|
||||
if lora_args.q_lora:
|
||||
model = prepare_model_for_kbit_training(
|
||||
model, use_gradient_checkpointing=training_args.gradient_checkpointing
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
|
||||
model.base_model.resampler.requires_grad_(True)
|
||||
if training_args.tune_vision:
|
||||
model.base_model.vpm.requires_grad_(True)
|
||||
model.base_model.resampler.requires_grad_(True)
|
||||
model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
|
||||
if training_args.gradient_checkpointing:
|
||||
model.enable_input_require_grads()
|
||||
|
||||
|
||||
@@ -28,10 +28,10 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
|
||||
--remove_unused_columns false \
|
||||
--label_names "labels" \
|
||||
--prediction_loss_only false \
|
||||
--bf16 true \
|
||||
--bf16_full_eval true \
|
||||
--fp16 false \
|
||||
--fp16_full_eval false \
|
||||
--bf16 false \
|
||||
--bf16_full_eval false \
|
||||
--fp16 true \
|
||||
--fp16_full_eval true \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--tune_vision true \
|
||||
|
||||
@@ -28,8 +28,10 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
|
||||
--remove_unused_columns false \
|
||||
--label_names "labels" \
|
||||
--prediction_loss_only false \
|
||||
--bf16 true \
|
||||
--bf16_full_eval true \
|
||||
--bf16 false \
|
||||
--bf16_full_eval false \
|
||||
--fp16 true \
|
||||
--fp16_full_eval true \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--tune_vision true \
|
||||
|
||||
@@ -93,12 +93,12 @@ model = AutoPeftModelForCausalLM.from_pretrained(
|
||||
|
||||
### Model Fine-tuning Memory Usage Statistics
|
||||
|
||||
The following table presents the memory usage of the model when fine-tuning using NVIDIA A100 (80GiB) GPUs under different numbers of GPUs. The fine-tuning was performed with the DeepSpeed Zero-2 optimization and Gradient Checkpointing techniques, with a maximum length set to 2048 and batch size set to 1.
|
||||
The following table presents the memory usage of the model when fine-tuning using NVIDIA A100 (80GiB) GPUs under different numbers of GPUs. The fine-tuning was performed with the DeepSpeed Zero-3 optimization, Gradient Checkpointing techniques and offloading optimizer as well as parameters memory to cpu, with a maximum length set to 2048 and batch size set to 1. You refer to [deepspeed zero stage](https://huggingface.co/docs/transformers/v4.41.2/en/deepspeed#select-a-zero-stage) to reduce memory cost.
|
||||
|
||||
| Fine-tuning Method | GPUs: 2 | GPUs: 4 | GPUs: 8 |
|
||||
|--------------------|---------|---------|---------|
|
||||
| LoRA Fine-tuning | 31.2 GiB| 29.3 GiB| 28.4GiB |
|
||||
| Full Parameters Fine-tuning | Out of memory | 75.0 GiB | 51.2GiB |
|
||||
| LoRA Fine-tuning | 14.4 GiB| 13.6 GiB| 13.1 GiB |
|
||||
| Full Parameters Fine-tuning | 16.0 GiB | 15.8 GiB | 15.63GiB |
|
||||
|
||||
### Notes
|
||||
- **Fine-tuning Method**: Displays two different fine-tuning strategies, LoRA fine-tuning and Full parameters fine-tuning.
|
||||
|
||||
Reference in New Issue
Block a user