Update to MiniCPM-V 2.6

This commit is contained in:
yiranyyu
2024-08-06 12:26:49 +08:00
parent 1cb882d473
commit b1a15299e6
28 changed files with 3692 additions and 191 deletions

View File

@@ -6,6 +6,8 @@ from dataclasses import dataclass, field
from functools import partial
from typing import Dict, List, Optional, Union, Literal, Tuple
from types import MethodType
from torchvision import transforms
import torch
import transformers
from accelerate.utils import DistributedType
@@ -130,6 +132,18 @@ def make_supervised_data_module(
)
def build_transform():
    """Build the image preprocessing pipeline passed to the data module.

    Returns:
        A ``transforms.Compose`` that applies ``transforms.ToTensor()``
        followed by ``transforms.Normalize`` with a per-channel mean and
        std of 0.5 for each of the three channels.
    """
    # Same values as timm.data.IMAGENET_INCEPTION_MEAN / IMAGENET_INCEPTION_STD;
    # hard-coded here rather than imported from timm.
    inception_mean = (0.5, 0.5, 0.5)
    inception_std = (0.5, 0.5, 0.5)

    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=inception_mean, std=inception_std)

    # Order matters: Normalize operates on the tensor produced by ToTensor.
    return transforms.Compose([to_tensor, normalize])
def get_parameter_number(model):
trainable_params, all_param = 0, 0
for param in model.parameters():
@@ -248,10 +262,11 @@ def train():
else:
batch_vision = False
transform_func = build_transform()
data_module = make_supervised_data_module(
tokenizer=tokenizer,
data_args=data_args,
transform=model.transform,
transform=transform_func,
data_collator=data_collator,
slice_config=slice_config,
llm_type=llm_type,