Update trainer.py

qianyu chen committed 2025-09-12 15:53:48 +08:00 (committed by GitHub)
parent c821cbd7c8
commit e41152f89c


@@ -7,7 +7,7 @@ from transformers.trainer_pt_utils import nested_detach
 from transformers.utils import is_sagemaker_mp_enabled
 from transformers.trainer import *
 from transformers.integrations import is_deepspeed_zero3_enabled
 from typing import Dict, List, Optional, Tuple
 class CPMTrainer(Trainer):
     def compute_loss(self, model, inputs, return_outputs=False):
@@ -170,7 +170,7 @@ class CPMTrainer(Trainer):
         return (loss, logits, labels)
-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.
@@ -190,7 +190,6 @@ class CPMTrainer(Trainer):
""" """
model.train() model.train()
inputs = self._prepare_inputs(inputs) inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device) return loss_mb.reduce_mean().detach().to(self.args.device)
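
Context for the signature change: newer transformers releases (around v4.46 onward) call Trainer.training_step with an extra num_items_in_batch argument, so a subclass override that only accepts (model, inputs) fails with a TypeError at the first step. The sketch below illustrates that compatibility concern; CompatTrainer and the inspect-based check are illustrative, not code from this repository.

import inspect

from transformers import Trainer


class CompatTrainer(Trainer):
    """Hypothetical subclass showing how an overridden training_step can stay
    compatible whether or not the installed transformers version passes
    num_items_in_batch."""

    def training_step(self, model, inputs, num_items_in_batch=None):
        # Custom per-step logic would go here.
        # Forward the extra argument only when the parent implementation
        # declares it (newer transformers); older versions do not accept it.
        parent_params = inspect.signature(super().training_step).parameters
        if "num_items_in_batch" in parent_params:
            return super().training_step(model, inputs, num_items_in_batch)
        return super().training_step(model, inputs)

The committed change takes the simpler route of adding num_items_in_batch=None to the override's signature, which is sufficient as long as the custom body does not need the value itself.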