mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-05 18:29:18 +08:00
Model Fine-tuning Memory Usage Statistics (#160)
This commit is contained in:
@@ -3,6 +3,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union, Literal, Tuple
|
||||
from types import MethodType
|
||||
import torch
|
||||
@@ -133,6 +134,7 @@ def make_supervised_data_module(
|
||||
patch_size=14,
|
||||
query_nums=64,
|
||||
batch_vision=False,
|
||||
max_length=2048,
|
||||
) -> Dict:
|
||||
"""Make dataset and collator for supervised fine-tuning."""
|
||||
dataset_cls = SupervisedDataset
|
||||
@@ -169,7 +171,7 @@ def make_supervised_data_module(
|
||||
return dict(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=data_collator,
|
||||
data_collator= partial(data_collator, max_length=max_length),
|
||||
)
|
||||
|
||||
|
||||
@@ -287,6 +289,7 @@ def train():
|
||||
patch_size=model.config.patch_size,
|
||||
query_nums=model.config.query_num,
|
||||
batch_vision=batch_vision,
|
||||
max_length=training_args.model_max_length,
|
||||
)
|
||||
|
||||
trainer = CPMTrainer(
|
||||
|
||||
Reference in New Issue
Block a user