update q_lora code and memory cost with zero3 and offloading (#200)

qianyu chen
2024-06-05 14:13:02 +08:00
committed by GitHub
parent 6bd877970d
commit 74278de0f4
5 changed files with 30 additions and 13 deletions

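For context on the commit title: below is a minimal sketch of a DeepSpeed ZeRO-3 configuration with CPU offloading, the kind of setup the "memory cost with zero3 and offloading" wording refers to. The keys are standard DeepSpeed options, but every concrete value here is an illustrative assumption, not taken from this commit's changed files.

# Minimal ZeRO-3 + CPU offloading config, expressed as the Python dict
# DeepSpeed accepts; all values are assumptions for illustration.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": True},
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
    },
    "bf16": {"enabled": True},
    "train_micro_batch_size_per_gpu": 1,
}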

@@ -19,7 +19,7 @@ from transformers import AutoModel, AutoTokenizer
 from dataset import SupervisedDataset, data_collator
 from trainer import CPMTrainer
-from peft import LoraConfig, get_peft_model
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training


 @dataclass
 class ModelArguments:
@@ -220,7 +220,13 @@ def train():
     local_rank = training_args.local_rank
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
-    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
+    device_map = None
+    if lora_args.q_lora:
+        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
+        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
+            logging.warning(
+                "FSDP or ZeRO3 are incompatible with QLoRA."
+            )

     model = AutoModel.from_pretrained(
         model_args.model_name_or_path,
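As a standalone illustration of the loading path this hunk guards: under QLoRA the quantized model is pinned to the local rank's GPU via device_map, which is why ZeRO-3/FSDP, which want to shard and place parameters themselves, are flagged as incompatible. A minimal sketch follows; the checkpoint name and quantization settings are assumptions, not taken from this hunk.

import os
import torch
from transformers import AutoModel, BitsAndBytesConfig

# Quantize to 4-bit and pin the whole model to this rank's GPU.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-Llama3-V-2_5",  # assumed checkpoint, not named in this hunk
    quantization_config=bnb_config,
    device_map={"": int(os.environ.get("LOCAL_RANK", 0))},
    trust_remote_code=True,
)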
@@ -258,11 +264,15 @@ def train():
         def get_input_embeddings(self):
             return self.llm.get_input_embeddings()

         model.get_input_embeddings = MethodType(get_input_embeddings, model)
+        if lora_args.q_lora:
+            model = prepare_model_for_kbit_training(
+                model, use_gradient_checkpointing=training_args.gradient_checkpointing
+            )
         model = get_peft_model(model, lora_config)
-        model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
-        model.base_model.resampler.requires_grad_(True)
         if training_args.tune_vision:
             model.base_model.vpm.requires_grad_(True)
+        model.base_model.resampler.requires_grad_(True)
+        model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
         if training_args.gradient_checkpointing:
             model.enable_input_require_grads()
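The sequence the last hunk adds, shown end to end: prepare the quantized model for k-bit training before wrapping it in LoRA adapters, then re-enable gradients on the non-LoRA modules that should stay trainable. A sketch assuming a generic LoraConfig; the target_modules list and hyperparameters are illustrative, not from this commit.

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Casts norms/embeddings to fp32 and wires up gradient checkpointing
# on the quantized base model; must run before get_peft_model.
model = prepare_model_for_kbit_training(
    model, use_gradient_checkpointing=True
)
lora_config = LoraConfig(  # hyperparameters are assumptions
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # confirm only adapters + unfrozen parts train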