Modify eval_mm for MiniCPM-V 2.6

Author: Haoyu Li
Date: 2024-08-30 18:18:22 +00:00
parent ab1141ee45
commit 59224808a1
69 changed files with 8231 additions and 1818 deletions


@@ -9,6 +9,60 @@ import time
 import numpy as np
 import validators
 import mimetypes
+import multiprocessing as mp
+from .misc import toliststr
+from .vlm import decode_base64_to_image_file
+
+
+def decode_img_omni(tup):
+    root, im, p = tup
+    images = toliststr(im)
+    paths = toliststr(p)
+    if len(images) > 1 and len(paths) == 1:
+        paths = [osp.splitext(p)[0] + f'_{i}' + osp.splitext(p)[1] for i in range(len(images))]
+    assert len(images) == len(paths)
+    paths = [osp.join(root, p) for p in paths]
+    for p, im in zip(paths, images):
+        if osp.exists(p):
+            continue
+        if isinstance(im, str) and len(im) > 64:
+            decode_base64_to_image_file(im, p)
+    return paths
+
+
+def localize_df(data, dname, nproc=32):
+    assert 'image' in data
+    indices = list(data['index'])
+    indices_str = [str(x) for x in indices]
+    images = list(data['image'])
+    image_map = {x: y for x, y in zip(indices_str, images)}
+    root = LMUDataRoot()
+    root = osp.join(root, 'images', dname)
+    os.makedirs(root, exist_ok=True)
+    if 'image_path' in data:
+        img_paths = list(data['image_path'])
+    else:
+        img_paths = []
+        for i in indices_str:
+            if len(image_map[i]) <= 64:
+                idx = image_map[i]
+                assert idx in image_map and len(image_map[idx]) > 64
+                img_paths.append(f'{idx}.jpg')
+            else:
+                img_paths.append(f'{i}.jpg')
+    tups = [(root, im, p) for p, im in zip(img_paths, images)]
+    pool = mp.Pool(nproc)  # use the nproc argument instead of a hard-coded pool size
+    ret = pool.map(decode_img_omni, tups)
+    pool.close()
+    data.pop('image')
+    if 'image_path' not in data:
+        data['image_path'] = [x[0] if len(x) == 1 else x for x in ret]
+    return data


 def LMUDataRoot():
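A minimal usage sketch for the two helpers added in this hunk, assuming a hypothetical two-row dataset and a hypothetical dataset name 'MyDataset' (the tiny_b64 helper exists only for this sketch):

    import base64, io
    import pandas as pd
    from PIL import Image

    def tiny_b64():
        # sketch-only helper: one 8x8 JPEG, base64-encoded (well over 64 chars)
        buf = io.BytesIO()
        Image.new('RGB', (8, 8), 'blue').save(buf, format='JPEG')
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    data = pd.DataFrame({'index': [0, 1], 'image': [tiny_b64(), tiny_b64()]})
    data = localize_df(data, 'MyDataset')
    # the 'image' column is dropped; 'image_path' now names the decoded files
    # under ~/LMUData/images/MyDataset/ (0.jpg and 1.jpg)
    print(data['image_path'])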
@@ -17,10 +71,9 @@ def LMUDataRoot():
     home = osp.expanduser('~')
     root = osp.join(home, 'LMUData')
     os.makedirs(root, exist_ok=True)
-    # root = './LMUData'
-    # os.makedirs(root, exist_ok=True)
     return root


 def MMBenchOfficialServer(dataset_name):
     root = LMUDataRoot()
@@ -92,7 +145,7 @@ def dump(data, f, **kwargs):
     return handlers[suffix](data, f, **kwargs)


-def load(f):
+def load(f, fmt=None):
     def load_pkl(pth):
         return pickle.load(open(pth, 'rb'))
@@ -117,6 +170,9 @@ def load(f):
         return pd.read_csv(f, sep='\t')

     handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv)
+    if fmt is not None:
+        return handlers[fmt](f)
+
     suffix = f.split('.')[-1]
     return handlers[suffix](f)
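The new fmt argument overrides suffix-based dispatch, which helps when a file's extension does not match its contents; a short sketch with illustrative file names:

    records = load('predictions.txt', fmt='jsonl')  # parse a .txt file as JSONL
    df = load('results.xlsx')                       # unchanged: format inferred from the suffix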
@@ -134,9 +190,24 @@ def download_file(url, filename=None):
     if filename is None:
         filename = url.split('/')[-1]

-    with DownloadProgressBar(unit='B', unit_scale=True,
-                             miniters=1, desc=url.split('/')[-1]) as t:
-        urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)
+    # If HF_ENDPOINT is set, replace huggingface.co with it
+    if 'huggingface.co' in url and os.environ.get('HF_ENDPOINT', '') != '':
+        url = url.replace('huggingface.co', os.environ['HF_ENDPOINT'].split('://')[1])
+
+    try:
+        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t:
+            urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)
+    except Exception:
+        # Handle failed downloads from huggingface.co
+        if 'huggingface.co' in url:
+            url_new = url.replace('huggingface.co', 'hf-mirror.com')
+            try:
+                os.system(f'wget {url_new} -O {filename}')
+            except Exception:
+                raise Exception(f'Failed to download {url}')
+        else:
+            raise Exception(f'Failed to download {url}')
     return filename
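The new download path first honors a user-specified HF_ENDPOINT mirror, then falls back to hf-mirror.com via wget when the direct download raises; a sketch of the intended flow, with an illustrative URL:

    import os
    os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
    # huggingface.co is rewritten to the mirror host before the request is made;
    # if the download still fails, wget against hf-mirror.com is tried last
    fname = download_file('https://huggingface.co/datasets/foo/bar/resolve/main/data.tsv')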
@@ -210,7 +281,7 @@ def last_modified(pth):
 def parse_file(s):
-    if osp.exists(s):
+    if osp.exists(s) and s != '.':
         assert osp.isfile(s)
         suffix = osp.splitext(s)[1].lower()
         mime = mimetypes.types_map.get(suffix, 'unknown')
@@ -228,3 +299,20 @@ def parse_file(s):
         return ('url', s)
     else:
         return (None, s)
+
+
+def file_size(f, unit='GB'):
+    stats = os.stat(f)
+    div_map = {
+        'GB': 2 ** 30,
+        'MB': 2 ** 20,
+        'KB': 2 ** 10,
+    }
+    return stats.st_size / div_map[unit]
+
+
+def parquet_to_tsv(file_path):
+    data = pd.read_parquet(file_path)
+    pth = '/'.join(file_path.split('/')[:-1])
+    data_name = file_path.split('/')[-1].split('.')[0]
+    data.to_csv(osp.join(pth, f'{data_name}.tsv'), sep='\t', index=False)
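Usage of the two new file utilities is direct; a sketch with hypothetical paths:

    print(file_size('LMUData/MMBench.tsv', unit='MB'))  # size as a float, in MiB
    parquet_to_tsv('LMUData/MMBench.parquet')           # writes LMUData/MMBench.tsv next to it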


@@ -18,8 +18,8 @@ from multiprocessing import Pool, current_process
 from tqdm import tqdm
 import datetime
 import matplotlib.pyplot as plt
-import seaborn as sns
-from tabulate import tabulate_formats, tabulate
+from tabulate import tabulate
+from json import JSONDecoder
 from huggingface_hub import scan_cache_dir
 from sty import fg, bg, ef, rs
@@ -71,7 +71,7 @@ def bincount(lst):
         bins[item] += 1
     return bins


-def get_cache_path(repo_id):
+def get_cache_path(repo_id, branch=None):
     hf_cache_info = scan_cache_dir()
     repos = list(hf_cache_info.repos)
     repo = None
@@ -82,6 +82,8 @@ def get_cache_path(repo_id):
     if repo is None:
         return None
     revs = list(repo.revisions)
+    if branch is not None:
+        revs = [r for r in revs if r.refs == frozenset({branch})]
     rev2keep, last_modified = None, 0
     for rev in revs:
         if rev.last_modified > last_modified:
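With the new branch argument, only cached revisions whose refs match the given branch are considered; a hedged example (the repo id is illustrative, and the function is expected to return the newest matching snapshot path, or None when the repo is absent from the local HF cache):

    pth = get_cache_path('openbmb/MiniCPM-V-2_6', branch='main')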
@@ -189,3 +191,26 @@ def version_cmp(v1, v2, op='eq'):
     import operator
     op_func = getattr(operator, op)
     return op_func(version.parse(v1), version.parse(v2))
+
+
+def toliststr(s):
+    if isinstance(s, str) and (s[0] == '[') and (s[-1] == ']'):
+        return [str(x) for x in eval(s)]
+    elif isinstance(s, str):
+        return [s]
+    elif isinstance(s, list):
+        return [str(x) for x in s]
+    raise NotImplementedError
+
+
+def extract_json_objects(text, decoder=JSONDecoder()):
+    pos = 0
+    while True:
+        match = text.find('{', pos)
+        if match == -1: break
+        try:
+            result, index = decoder.raw_decode(text[match:])
+            yield result
+            pos = match + index
+        except ValueError:
+            pos = match + 1
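A quick sketch of the two added helpers on hand-written inputs:

    toliststr("['a.jpg', 'b.jpg']")   # -> ['a.jpg', 'b.jpg']
    toliststr('a.jpg')                # -> ['a.jpg']

    text = 'scores {"acc": 0.5} and {"f1": 0.7}'
    for obj in extract_json_objects(text):
        print(obj)                    # {'acc': 0.5}, then {'f1': 0.7}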


@@ -7,10 +7,53 @@ from uuid import uuid4
 import os.path as osp
 import base64
 from PIL import Image
-from .file import load, dump
+import sys

 Image.MAX_IMAGE_PIXELS = 1e9

+
+def rescale_img(img, tgt=None):
+    assert isinstance(tgt, tuple) and -1 in tgt
+    w, h = img.size
+    if tgt[0] != -1:
+        new_w, new_h = tgt[0], int(tgt[0] / w * h)
+    elif tgt[1] != -1:
+        new_w, new_h = int(tgt[1] / h * w), tgt[1]
+    img = img.resize((new_w, new_h))
+    return img
+
+
+def concat_images_vlmeval(images, target_size=-1, mode='h', return_image=False):
+    from .file import md5
+    ims = [Image.open(im) for im in images]
+    if target_size != -1:
+        ims = [
+            rescale_img(im, (-1, target_size) if mode == 'h' else (target_size, -1))
+            for im in ims
+        ]
+    ws, hs = [x.width for x in ims], [x.height for x in ims]
+    if mode == 'h':
+        new_w, new_h = sum(ws), max(hs)
+        dst = Image.new('RGB', (new_w, new_h))
+        for i, im in enumerate(ims):
+            dst.paste(im, (sum(ws[:i]), 0))
+    elif mode == 'v':
+        new_w, new_h = max(ws), sum(hs)
+        dst = Image.new('RGB', (new_w, new_h))
+        for i, im in enumerate(ims):
+            dst.paste(im, (0, sum(hs[:i])))
+    if return_image:
+        return dst
+    else:
+        _str = '\n'.join(images)
+        str_md5 = md5(_str)
+        tgt = osp.join('/tmp', str_md5 + '.jpg')
+        dst.save(tgt)
+        return tgt
+
 def mmqa_display(question, target_size=512):
     question = {k.lower(): v for k, v in question.items()}
     keys = list(question.keys())
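A minimal sketch of the concat_images_vlmeval helper added above (file names are hypothetical):

    # stitch two images horizontally, rescaling each to a height of 448 px first
    pth = concat_images_vlmeval(['a.jpg', 'b.jpg'], target_size=448, mode='h')
    # pth is an md5-named JPEG under /tmp; pass return_image=True for the PIL.Image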
@@ -41,14 +84,12 @@ def encode_image_to_base64(img, target_size=-1):
     # else, will set the max_size to (target_size, target_size)
     if img.mode in ('RGBA', 'P'):
         img = img.convert('RGB')
-    tmp = osp.join('/tmp', str(uuid4()) + '.jpg')
     if target_size > 0:
         img.thumbnail((target_size, target_size))
-    img.save(tmp)
-    with open(tmp, 'rb') as image_file:
-        image_data = image_file.read()
+    img_buffer = io.BytesIO()
+    img.save(img_buffer, format='JPEG')
+    image_data = img_buffer.getvalue()
     ret = base64.b64encode(image_data).decode('utf-8')
-    os.remove(tmp)
     return ret
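The rewritten body keeps the whole encode in memory: an io.BytesIO buffer replaces the uuid-named temp file under /tmp, removing the disk round-trip and the clean-up step. A small sketch on a synthetic image:

    from PIL import Image
    img = Image.new('RGB', (64, 64), 'red')
    b64 = encode_image_to_base64(img, target_size=32)  # thumbnail to fit 32x32, then base64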
@@ -110,6 +151,7 @@ def circular_pred(df, extract_func=None):
         extract_func = lambda x: x  # noqa: E731
     df = df.sort_values('index')
+    from vlmeval.utils import can_infer_option
     shift = int(1e6)
     choices = [extract_func(x) for x in df['prediction']]
@@ -118,9 +160,12 @@
     valid_map = {i: True for i in pred_map if i < 1e6}
     for i in df['index']:
         if i >= shift and pred_map[i] and pred_map[i - shift]:
-            if (
-                pred_map[i] not in list(string.ascii_uppercase) or  # noqa: W504
-                pred_map[i - shift] not in list(string.ascii_uppercase)
+            if pred_map[i] not in list(
+                string.ascii_uppercase
+            ) or pred_map[  # noqa: W504
+                i - shift
+            ] not in list(
+                string.ascii_uppercase
             ):
                 valid_map[i % shift] = False