support video sft and auto save and load all files

This commit is contained in:
fzc8578
2025-01-11 13:50:36 +08:00
parent 8464c94a7b
commit c5e82b1bc7
4 changed files with 170 additions and 22 deletions

View File

@@ -170,7 +170,7 @@ class CPMTrainer(Trainer):
return (loss, logits, labels)
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: int=None) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
@@ -245,6 +245,9 @@ class CPMTrainer(Trainer):
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
if getattr(self.model, "processor") is not None:
self.model.processor.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))