Model Fine-tuning Memory Usage Statistics (#160)

This commit is contained in:
qianyu chen
2024-05-28 11:41:27 +08:00
committed by GitHub
parent 7e12387362
commit f592fedb2e
4 changed files with 30 additions and 10 deletions

View File

@@ -3,6 +3,7 @@ import json
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from typing import Dict, List, Optional, Union, Literal, Tuple
from types import MethodType
import torch
@@ -133,6 +134,7 @@ def make_supervised_data_module(
patch_size=14,
query_nums=64,
batch_vision=False,
max_length=2048,
) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
dataset_cls = SupervisedDataset
@@ -169,7 +171,7 @@ def make_supervised_data_module(
return dict(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=data_collator,
data_collator= partial(data_collator, max_length=max_length),
)
@@ -287,6 +289,7 @@ def train():
patch_size=model.config.patch_size,
query_nums=model.config.query_num,
batch_vision=batch_vision,
max_length=training_args.model_max_length,
)
trainer = CPMTrainer(