Update zero3 code and OOM FAQs (#188)
@@ -1,11 +1,12 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import deepspeed
 from transformers import Trainer
 from transformers.trainer_pt_utils import nested_detach
 from transformers.utils import is_sagemaker_mp_enabled
 from transformers.trainer import *
 import deepspeed
+from transformers.integrations import is_deepspeed_zero3_enabled
 class CPMTrainer(Trainer):
     def compute_loss(self, model, inputs, return_outputs=False):
@@ -13,11 +14,19 @@ class CPMTrainer(Trainer):
             labels = inputs.pop("labels")
         else:
             labels = None
-        if not self.args.use_lora:
-            outputs = self.model(data = inputs, use_cache=False)
+        self.model.resampler.pos_embed = self.model.resampler.pos_embed.to(self.model.device)
+        if is_deepspeed_zero3_enabled():
+            with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0):
+                if not self.args.use_lora:
+                    outputs = self.model(data = inputs, use_cache=False)
+                else:
+                    outputs = self.model.base_model(data = inputs, use_cache=False)
         else:
-            outputs = self.model.base_model(data = inputs, use_cache=False)
+            if not self.args.use_lora:
+                outputs = self.model(data = inputs, use_cache=False)
+            else:
+                outputs = self.model.base_model(data = inputs, use_cache=False)
 
         if labels is not None:
             # Flatten the tokens
             loss_fct = nn.CrossEntropyLoss()
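For context, the new branch wraps the forward pass in deepspeed.zero.GatheredParameters when ZeRO-3 is enabled, presumably so that the resampler's attention weights, which ZeRO-3 keeps partitioned across ranks, are fully materialized while the forward runs. A minimal standalone sketch of that pattern follows; it is not part of the diff, and the helper name and module argument are illustrative:

    # Sketch of the GatheredParameters pattern used above; names are illustrative.
    import deepspeed
    import torch.nn as nn
    from transformers.integrations import is_deepspeed_zero3_enabled

    def forward_with_gathered_params(module: nn.Module, *args, **kwargs):
        """Run a forward pass with the module's ZeRO-3-partitioned weights gathered."""
        if is_deepspeed_zero3_enabled():
            # Under ZeRO-3 each parameter is sharded across ranks; GatheredParameters
            # temporarily all-gathers them and re-partitions them on context exit.
            # modifier_rank=0 means only rank 0's in-place modifications are kept.
            with deepspeed.zero.GatheredParameters(module.parameters(), modifier_rank=0):
                return module(*args, **kwargs)
        # Without ZeRO-3 the parameters are already fully materialized.
        return module(*args, **kwargs)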
@@ -165,3 +174,50 @@ class CPMTrainer(Trainer):
             logits = logits[0]
 
         return (loss, logits, labels)
+
+    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+        """
+        Perform a training step on a batch of inputs.
+
+        Subclass and override to inject custom behavior.
+
+        Args:
+            model (`nn.Module`):
+                The model to train.
+            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument `labels`. Check your model's documentation for all accepted arguments.
+
+        Return:
+            `torch.Tensor`: The tensor with training loss on this batch.
+        """
+        model.train()
+        inputs = self._prepare_inputs(inputs)
+
+        if is_sagemaker_mp_enabled():
+            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+            return loss_mb.reduce_mean().detach().to(self.args.device)
+
+        with self.compute_loss_context_manager():
+            loss = self.compute_loss(model, inputs)
+
+        del inputs
+        torch.cuda.empty_cache()
+
+        if self.args.n_gpu > 1:
+            loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
+        if self.use_apex:
+            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            if is_deepspeed_zero3_enabled():
+                with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0):
+                    self.accelerator.backward(loss)
+            else:
+                self.accelerator.backward(loss)
+
+        return loss.detach() / self.args.gradient_accumulation_steps
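The OOM-related part of the new training_step is the `del inputs` / `torch.cuda.empty_cache()` pair between the forward and backward passes: dropping the last reference to the batch and returning unused cached blocks to the allocator lowers peak GPU memory at the cost of some allocator reuse. A standalone sketch of the same idea is below; the model call assumes a Hugging-Face-style model that returns an object with a `.loss` field, and all names are illustrative:

    # Sketch only: free the input batch between forward and backward to reduce peak memory.
    import torch

    def forward_then_release(model: torch.nn.Module, batch: dict) -> torch.Tensor:
        """Compute the loss, then free the batch before backward."""
        loss = model(**batch).loss      # forward pass; autograd keeps the activations it needs
        del batch                       # drop the last Python reference to the input tensors
        torch.cuda.empty_cache()        # hand unused cached blocks back to the CUDA driver
        return loss                     # loss.backward() still works: saved activations are untouched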