Mirror of https://github.com/OpenBMB/MiniCPM-V.git
Update trainer.py
@@ -7,7 +7,7 @@ from transformers.trainer_pt_utils import nested_detach
 from transformers.utils import is_sagemaker_mp_enabled
 from transformers.trainer import *
 from transformers.integrations import is_deepspeed_zero3_enabled
 from typing import Dict, List, Optional, Tuple

 class CPMTrainer(Trainer):
     def compute_loss(self, model, inputs, return_outputs=False):
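For context, CPMTrainer customizes Hugging Face's Trainer by overriding compute_loss (and training_step, shown in the next hunk). The snippet below is a minimal sketch of that override pattern only, not the MiniCPM-V implementation; the class name and loss handling are illustrative assumptions.

    # Minimal sketch of overriding Trainer.compute_loss (illustrative only,
    # not the MiniCPM-V code): run the model on the batch and return its loss.
    import torch
    from transformers import Trainer

    class SketchLossTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False):
            outputs = model(**inputs)
            # Models may return a dict/ModelOutput with a "loss" entry or a
            # tuple whose first element is the loss.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
            return (loss, outputs) if return_outputs else loss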
@@ -170,7 +170,7 @@ class CPMTrainer(Trainer):

         return (loss, logits, labels)

-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.
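The only change in this hunk is the extra num_items_in_batch parameter. Newer transformers releases (roughly 4.46 and later) pass this argument when calling training_step, so an override that keeps the old two-argument signature fails with a TypeError. Below is a minimal compatibility sketch under that assumption; the class name is illustrative and this is not the MiniCPM-V code.

    # Sketch of a forward-compatible training_step override (illustrative only).
    from typing import Any, Dict, Union

    import torch
    import torch.nn as nn
    from transformers import Trainer

    class CompatTrainer(Trainer):
        def training_step(
            self,
            model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]],
            num_items_in_batch=None,  # accepted so newer Trainer loops can call this override
        ) -> torch.Tensor:
            # Defer to the parent implementation; older releases never pass the
            # third argument, and the parent defaults it to None when omitted.
            return super().training_step(model, inputs)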
@@ -190,7 +190,6 @@ class CPMTrainer(Trainer):
         """
         model.train()
         inputs = self._prepare_inputs(inputs)

         if is_sagemaker_mp_enabled():
             loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
             return loss_mb.reduce_mean().detach().to(self.args.device)
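As a usage reference, a custom trainer like this is driven the same way as the stock one through the standard Trainer API; the placeholder objects below (model, training arguments, dataset, collator) are assumed to come from the fine-tuning script and are not defined here.

    # Hypothetical usage, following the standard transformers Trainer API.
    trainer = CPMTrainer(
        model=model,                  # the model being fine-tuned
        args=training_args,           # a transformers.TrainingArguments instance
        train_dataset=train_dataset,
        data_collator=data_collator,
    )
    trainer.train()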