mirror of https://github.com/OpenBMB/MiniCPM-V.git

Update trainer.py
@@ -7,7 +7,7 @@ from transformers.trainer_pt_utils import nested_detach
 from transformers.utils import is_sagemaker_mp_enabled
 from transformers.trainer import *
 from transformers.integrations import is_deepspeed_zero3_enabled
-
+from typing import Dict, List, Optional, Tuple
 
 class CPMTrainer(Trainer):
     def compute_loss(self, model, inputs, return_outputs=False):
@@ -170,7 +170,7 @@ class CPMTrainer(Trainer):
 
         return (loss, logits, labels)
 
-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.
 
@@ -189,8 +189,7 @@ class CPMTrainer(Trainer):
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
-       inputs = self._prepare_inputs(inputs)
-
+       inputs = self._prepare_inputs(inputs)
        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)
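
Note: the signature change above tracks recent transformers releases (v4.46 and later), which pass an extra num_items_in_batch argument to Trainer.training_step; an override written against the older two-argument signature fails with a TypeError there. Giving the argument a None default keeps CPMTrainer working on both older and newer transformers. The second hunk also drops a duplicated self._prepare_inputs(inputs) call so inputs are prepared exactly once per step. A minimal sketch of the same compatibility pattern, assuming transformers and torch are installed (CompatTrainer is a hypothetical name for illustration, not part of MiniCPM-V):

# Sketch of a Trainer override compatible with transformers both before
# and after num_items_in_batch was introduced (~v4.46). CompatTrainer is
# a hypothetical illustrative subclass, not the repository's code.
from typing import Any, Dict, Optional, Union

import torch
import torch.nn as nn
from transformers import Trainer


class CompatTrainer(Trainer):
    def training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        # Newer transformers passes this; the None default keeps the
        # method callable from older versions that do not.
        num_items_in_batch: Optional[int] = None,
    ) -> torch.Tensor:
        model.train()
        inputs = self._prepare_inputs(inputs)  # prepare exactly once
        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)
        self.accelerator.backward(loss)
        # The stock Trainer additionally handles gradient-accumulation
        # scaling, Apex, and SageMaker model parallelism; omitted here
        # for brevity.
        return loss.detach()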