mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-05 18:29:18 +08:00
Update to MiniCPM-V 2.6
This commit is contained in:
@@ -6,6 +6,8 @@ from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union, Literal, Tuple
|
||||
from types import MethodType
|
||||
from torchvision import transforms
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from accelerate.utils import DistributedType
|
||||
@@ -130,6 +132,18 @@ def make_supervised_data_module(
|
||||
)
|
||||
|
||||
|
||||
def build_transform():
|
||||
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
|
||||
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
|
||||
return transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def get_parameter_number(model):
|
||||
trainable_params, all_param = 0, 0
|
||||
for param in model.parameters():
|
||||
@@ -248,10 +262,11 @@ def train():
|
||||
else:
|
||||
batch_vision = False
|
||||
|
||||
transform_func = build_transform()
|
||||
data_module = make_supervised_data_module(
|
||||
tokenizer=tokenizer,
|
||||
data_args=data_args,
|
||||
transform=model.transform,
|
||||
transform=transform_func,
|
||||
data_collator=data_collator,
|
||||
slice_config=slice_config,
|
||||
llm_type=llm_type,
|
||||
|
||||
Reference in New Issue
Block a user