From 8d9c86a919cb1b0b6ff46a0c17752062b61ce010 Mon Sep 17 00:00:00 2001 From: qianyu chen <38046403+qyc-98@users.noreply.github.com> Date: Fri, 31 May 2024 12:44:21 +0800 Subject: [PATCH] Update zero3 code and OOM FQAs (#188) --- finetune/finetune.py | 8 +++++ finetune/finetune_ds.sh | 2 ++ finetune/finetune_lora.sh | 2 ++ finetune/readme.md | 56 +++++++++++++++++++++++++++++++ finetune/trainer.py | 70 +++++++++++++++++++++++++++++++++++---- 5 files changed, 131 insertions(+), 7 deletions(-) diff --git a/finetune/finetune.py b/finetune/finetune.py index c86916c..c0200a4 100644 --- a/finetune/finetune.py +++ b/finetune/finetune.py @@ -50,6 +50,8 @@ class TrainingArguments(transformers.TrainingArguments): tune_llm: Optional[bool] = field(default=True) llm_type: str = field(default="minicpm") use_lora: Optional[bool] = field(default=False) + max_slice_nums: Optional[int] = field(default=9) + scale_resolution: Optional[int] = field(default=448) @dataclass @@ -272,11 +274,17 @@ def train(): rank0_print(f'llm_type={llm_type}') + # Load data if hasattr(model.config, "slice_config"): + model.config.slice_config.max_slice_nums = training_args.max_slice_nums + model.config.slice_config.scale_resolution = training_args.scale_resolution slice_config = model.config.slice_config.to_dict() else: + model.config.max_slice_nums = training_args.max_slice_nums + model.config.scale_resolution = training_args.scale_resolution slice_config = model.config.to_dict() + if hasattr(model.config, "batch_vision_input"): batch_vision = model.config.batch_vision_input else: diff --git a/finetune/finetune_ds.sh b/finetune/finetune_ds.sh index 45c00fe..716239a 100644 --- a/finetune/finetune_ds.sh +++ b/finetune/finetune_ds.sh @@ -37,6 +37,8 @@ torchrun $DISTRIBUTED_ARGS finetune.py \ --tune_vision true \ --tune_llm true \ --model_max_length 2048 \ + --max_slice_nums 9 \ + --scale_resolution 448 \ --max_steps 10000 \ --eval_steps 1000 \ --output_dir output/output_minicpmv2 \ diff --git a/finetune/finetune_lora.sh b/finetune/finetune_lora.sh index deba0d5..b05b0f9 100644 --- a/finetune/finetune_lora.sh +++ b/finetune/finetune_lora.sh @@ -37,6 +37,8 @@ torchrun $DISTRIBUTED_ARGS finetune.py \ --use_lora true \ --lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj)" \ --model_max_length 2048 \ + --max_slice_nums 9 \ + --scale_resolution 448 \ --max_steps 10000 \ --eval_steps 1000 \ --output_dir output/output_minicpmv2_lora \ diff --git a/finetune/readme.md b/finetune/readme.md index f8ab49c..a962622 100644 --- a/finetune/readme.md +++ b/finetune/readme.md @@ -108,6 +108,62 @@ The following table presents the memory usage of the model when fine-tuning usin ### Finetuning FAQs +
+Q:When you encounter Out of Memory (OOM) issues during training large models, you can try the following methods to resolve or mitigate the issue: + +A:When you face Out of Memory (OOM) issues during training large models, the following strategies may help resolve or mitigate the problem: +#### Adjust Model Hyperparameters +- **Reduce `max_model_length`**: Decreasing the maximum sequence length the model processes can significantly reduce the memory required for each operation. For example, reducing the maximum length from 2048 to 1200 or another value suitable for your dataset. +``` +--model_max_length 1200 + +``` +- **Lower `batch_size`**: Reducing the amount of data processed in each batch helps decrease memory consumption. +``` +--batch_size 1 + ``` +- **Lower image resolution**: If your model processes image data, reducing the input resolution of images can effectively decrease memory usage. +``` +--scale_resolution 448 +``` +- **Reduce the number of slices (`slice`)**: When handling large datasets such as large images files, reducing the number of slices processed each time can lower memory requirements. +``` +--max_slice_nums 9 +``` + +#### Reduce Training Model Parameters +- **Do not train VPM (Visual Processing Module)**: You can adjust hyperparameters in the finetune script to opt out of training the visual processing module to save memory. +``` +--tune_vision false +``` +- **Use LoRA finetuning**: Refer to the [LoRA finetuning](#LoRA-finetuning) section. + +#### Optimize with DeepSpeed +- **Configure DeepSpeed Zero Stage 2**: Use the following configuration to offload optimizer parameters to the CPU, reducing memory pressure on the GPU: + ```json + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + } + } +- **Configure DeepSpeed Zero Stage 3**:Further offload model parameters and optimizer parameters to the CPU, further reducing GPU memory usage: +```json +"zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "offload_param": { + "device": "cpu", + "pin_memory": true + } +} +``` +You can visit [huggingface deepspeed](https://huggingface.co/docs/transformers/deepspeed) to find out more about how to use DeepSpeed. +
Q: Encounter an error while using the AutoPeftModelForCausalLM to load a checkpoint that has undergone lora fine-tuning diff --git a/finetune/trainer.py b/finetune/trainer.py index 773e358..fbb8c89 100644 --- a/finetune/trainer.py +++ b/finetune/trainer.py @@ -1,11 +1,12 @@ -from typing import Any, Dict, List, Optional, Tuple, Union - import torch import torch.nn as nn +import deepspeed from transformers import Trainer from transformers.trainer_pt_utils import nested_detach from transformers.utils import is_sagemaker_mp_enabled - +from transformers.trainer import * +import deepspeed +from transformers.integrations import is_deepspeed_zero3_enabled class CPMTrainer(Trainer): def compute_loss(self, model, inputs, return_outputs=False): @@ -13,11 +14,19 @@ class CPMTrainer(Trainer): labels = inputs.pop("labels") else: labels = None - if not self. args.use_lora: - outputs = self.model(data = inputs, use_cache=False) + self.model.resampler.pos_embed = self.model.resampler.pos_embed.to(self.model.device) + if is_deepspeed_zero3_enabled(): + with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0): + if not self.args.use_lora: + outputs = self.model(data = inputs, use_cache=False) + else: + outputs = self.model.base_model(data = inputs, use_cache=False) else: - outputs = self.model.base_model(data = inputs, use_cache=False) - + if not self.args.use_lora: + outputs = self.model(data = inputs, use_cache=False) + else: + outputs = self.model.base_model(data = inputs, use_cache=False) + if labels is not None: # Flatten the tokens loss_fct = nn.CrossEntropyLoss() @@ -165,3 +174,50 @@ class CPMTrainer(Trainer): logits = logits[0] return (loss, logits, labels) + def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + """ + Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. + """ + model.train() + inputs = self._prepare_inputs(inputs) + + if is_sagemaker_mp_enabled(): + loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) + return loss_mb.reduce_mean().detach().to(self.args.device) + + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + + del inputs + torch.cuda.empty_cache() + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + if is_deepspeed_zero3_enabled(): + with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0): + self.accelerator.backward(loss) + else: + self.accelerator.backward(loss) + + return loss.detach() / self.args.gradient_accumulation_steps + +