diff --git a/finetune/trainer.py b/finetune/trainer.py index bea2eff..fa57bd0 100644 --- a/finetune/trainer.py +++ b/finetune/trainer.py @@ -5,9 +5,9 @@ from transformers import Trainer from transformers.trainer_pt_utils import nested_detach from transformers.utils import is_sagemaker_mp_enabled from transformers.trainer import * -import deepspeed from transformers.integrations import is_deepspeed_zero3_enabled + class CPMTrainer(Trainer): def compute_loss(self, model, inputs, return_outputs=False): if "labels" in inputs: