imported deepseed twice. Removed unnecesary extra import

This commit is contained in:
KabakaWilliam
2024-06-12 12:47:26 +01:00
parent fccdf02cbc
commit 7017aa1754

View File

@@ -5,9 +5,9 @@ from transformers import Trainer
from transformers.trainer_pt_utils import nested_detach
from transformers.utils import is_sagemaker_mp_enabled
from transformers.trainer import *
import deepspeed
from transformers.integrations import is_deepspeed_zero3_enabled
class CPMTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
if "labels" in inputs: