Merge pull request #257 from KabakaWilliam/main

imported deepseed twice. Removed unnecesary extra import
This commit is contained in:
Hongji Zhu
2024-06-17 20:32:58 +08:00
committed by GitHub

View File

@@ -5,9 +5,9 @@ from transformers import Trainer
from transformers.trainer_pt_utils import nested_detach
from transformers.utils import is_sagemaker_mp_enabled
from transformers.trainer import *
import deepspeed
from transformers.integrations import is_deepspeed_zero3_enabled
class CPMTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
if "labels" in inputs: