Update trainer.py

Author: qianyu chen
Date: 2025-09-12 15:53:48 +08:00
Committed by: GitHub
Parent: c821cbd7c8
Commit: e41152f89c

@@ -7,7 +7,7 @@ from transformers.trainer_pt_utils import nested_detach
 from transformers.utils import is_sagemaker_mp_enabled
 from transformers.trainer import *
 from transformers.integrations import is_deepspeed_zero3_enabled
 from typing import Dict, List, Optional, Tuple

 class CPMTrainer(Trainer):
     def compute_loss(self, model, inputs, return_outputs=False):
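
Note that transformers 4.46 and later also pass a num_items_in_batch keyword to compute_loss. Within this diff the compute_loss signature is left unchanged, which works because CPMTrainer overrides training_step and invokes compute_loss itself; a subclass that relies on the stock training_step would need to accept the extra argument. A minimal sketch of that tolerant signature (the class name and loss extraction are illustrative assumptions, not code from this repository):

    # Hypothetical sketch: a compute_loss override that works across
    # transformers versions that do or do not pass num_items_in_batch.
    from typing import Any, Dict, Optional, Union

    import torch
    from transformers import Trainer

    class CompatLossTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False,
                         num_items_in_batch=None, **kwargs):
            # transformers >= 4.46 supplies num_items_in_batch (the number of
            # tokens contributing to the loss) so the loss can be normalized
            # correctly under gradient accumulation; older releases never
            # pass it, so it must default to None.
            outputs = model(**inputs)
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
            return (loss, outputs) if return_outputs else loss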
@@ -170,7 +170,7 @@ class CPMTrainer(Trainer):
         return (loss, logits, labels)

-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.
@@ -190,7 +190,6 @@ class CPMTrainer(Trainer):
"""
model.train()
inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)
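
The added num_items_in_batch parameter is the substance of this commit: starting with transformers 4.46, the inner training loop in Trainer.train() passes num_items_in_batch into training_step, so an override that keeps the old two-argument signature fails with a TypeError. Defaulting the parameter to None keeps the override working on older transformers releases that never pass it. A minimal sketch of the same compatibility pattern, with the CPM-specific loss logic elided (the class name is an illustrative assumption):

    # Hypothetical sketch: a training_step override compatible with
    # transformers versions both before and after the num_items_in_batch change.
    from typing import Any, Dict, Union

    import torch
    import torch.nn as nn
    from transformers import Trainer

    class CompatTrainer(Trainer):
        def training_step(
            self,
            model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]],
            num_items_in_batch=None,  # supplied by transformers >= 4.46, absent before
        ) -> torch.Tensor:
            model.train()
            inputs = self._prepare_inputs(inputs)
            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, inputs)
            self.accelerator.backward(loss)
            # Report the detached, accumulation-scaled loss, as the stock Trainer does.
            return loss.detach() / self.args.gradient_accumulation_steps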