mirror of https://github.com/OpenBMB/MiniCPM-V.git (synced 2026-02-05 18:29:18 +08:00)
Modify eval_mm for MiniCPM-V 2.6
@@ -9,6 +9,60 @@ import time
 import numpy as np
+import validators
+import mimetypes
+import multiprocessing as mp
+from .misc import toliststr
+from .vlm import decode_base64_to_image_file
+
+
+def decode_img_omni(tup):
+    root, im, p = tup
+    images = toliststr(im)
+    paths = toliststr(p)
+    if len(images) > 1 and len(paths) == 1:
+        paths = [osp.splitext(p)[0] + f'_{i}' + osp.splitext(p)[1] for i in range(len(images))]
+
+    assert len(images) == len(paths)
+    paths = [osp.join(root, p) for p in paths]
+    for p, im in zip(paths, images):
+        if osp.exists(p):
+            continue
+        if isinstance(im, str) and len(im) > 64:
+            decode_base64_to_image_file(im, p)
+    return paths
+
+
+def localize_df(data, dname, nproc=32):
+    assert 'image' in data
+    indices = list(data['index'])
+    indices_str = [str(x) for x in indices]
+    images = list(data['image'])
+    image_map = {x: y for x, y in zip(indices_str, images)}
+
+    root = LMUDataRoot()
+    root = osp.join(root, 'images', dname)
+    os.makedirs(root, exist_ok=True)
+
+    if 'image_path' in data:
+        img_paths = list(data['image_path'])
+    else:
+        img_paths = []
+        for i in indices_str:
+            if len(image_map[i]) <= 64:
+                idx = image_map[i]
+                assert idx in image_map and len(image_map[idx]) > 64
+                img_paths.append(f'{idx}.jpg')
+            else:
+                img_paths.append(f'{i}.jpg')
+
+    tups = [(root, im, p) for p, im in zip(img_paths, images)]
+
+    pool = mp.Pool(32)
+    ret = pool.map(decode_img_omni, tups)
+    pool.close()
+    data.pop('image')
+    if 'image_path' not in data:
+        data['image_path'] = [x[0] if len(x) == 1 else x for x in ret]
+    return data
+
+
 def LMUDataRoot():
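
Note: the two functions added above materialize base64-encoded images from a benchmark DataFrame into files under the LMUData root. A minimal sketch of the input layout they expect (hypothetical data; an 'image' cell of 64 characters or fewer is treated as a back-reference to another row's index):

    import base64
    import io
    import pandas as pd
    from PIL import Image

    # Build a tiny JPEG payload so the base64 string exceeds 64 characters.
    buf = io.BytesIO()
    Image.new('RGB', (8, 8), 'red').save(buf, format='JPEG')
    b64 = base64.b64encode(buf.getvalue()).decode('utf-8')

    data = pd.DataFrame({'index': [1, 2], 'image': [b64, '1']})
    # localize_df(data, 'demo') would decode row 1 to <root>/images/demo/1.jpg,
    # point row 2 at the same file, drop 'image', and add an 'image_path' column.
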
@@ -17,10 +71,9 @@ def LMUDataRoot():
     home = osp.expanduser('~')
     root = osp.join(home, 'LMUData')
     os.makedirs(root, exist_ok=True)
     # root = './LMUData'
     # os.makedirs(root, exist_ok=True)
     return root


 def MMBenchOfficialServer(dataset_name):
     root = LMUDataRoot()
@@ -92,7 +145,7 @@ def dump(data, f, **kwargs):
     return handlers[suffix](data, f, **kwargs)


-def load(f):
+def load(f, fmt=None):
     def load_pkl(pth):
         return pickle.load(open(pth, 'rb'))

@@ -117,6 +170,9 @@ def load(f):
         return pd.read_csv(f, sep='\t')

     handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv)
+    if fmt is not None:
+        return handlers[fmt](f)
+
     suffix = f.split('.')[-1]
     return handlers[suffix](f)
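
The new fmt argument lets a caller bypass suffix sniffing when a file name carries no (or a misleading) extension. Sketch with a hypothetical path:

    # Force the TSV parser for a cached download that has no extension.
    data = load('/tmp/mmbench_dev_cache', fmt='tsv')
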
@@ -134,9 +190,24 @@ def download_file(url, filename=None):
     if filename is None:
         filename = url.split('/')[-1]

-    with DownloadProgressBar(unit='B', unit_scale=True,
-                             miniters=1, desc=url.split('/')[-1]) as t:
-        urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)
+    # If HF_ENDPOINT is set, replace huggingface.co with it
+    if 'huggingface.co' in url and os.environ.get('HF_ENDPOINT', '') != '':
+        url = url.replace('huggingface.co', os.environ['HF_ENDPOINT'].split('://')[1])
+
+    try:
+        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t:
+            urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)
+    except:
+        # Handle Failed Downloads from huggingface.co
+        if 'huggingface.co' in url:
+            url_new = url.replace('huggingface.co', 'hf-mirror.com')
+            try:
+                os.system(f'wget {url_new} -O {filename}')
+            except:
+                raise Exception(f'Failed to download {url}')
+        else:
+            raise Exception(f'Failed to download {url}')

     return filename
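
The download path now honors an HF_ENDPOINT mirror before the attempt and falls back to hf-mirror.com (via wget) when huggingface.co fails. The rewrite itself is plain string substitution, as a standalone sketch shows (hypothetical URL):

    import os

    os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
    url = 'https://huggingface.co/datasets/foo/bar.tsv'
    if 'huggingface.co' in url and os.environ.get('HF_ENDPOINT', '') != '':
        url = url.replace('huggingface.co', os.environ['HF_ENDPOINT'].split('://')[1])
    print(url)  # https://hf-mirror.com/datasets/foo/bar.tsv
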
@@ -210,7 +281,7 @@ def last_modified(pth):


 def parse_file(s):
-    if osp.exists(s):
+    if osp.exists(s) and s != '.':
         assert osp.isfile(s)
         suffix = osp.splitext(s)[1].lower()
         mime = mimetypes.types_map.get(suffix, 'unknown')

@@ -228,3 +299,20 @@ def parse_file(s):
         return ('url', s)
     else:
         return (None, s)
+
+
+def file_size(f, unit='GB'):
+    stats = os.stat(f)
+    div_map = {
+        'GB': 2 ** 30,
+        'MB': 2 ** 20,
+        'KB': 2 ** 10,
+    }
+    return stats.st_size / div_map[unit]
+
+
+def parquet_to_tsv(file_path):
+    data = pd.read_parquet(file_path)
+    pth = '/'.join(file_path.split('/')[:-1])
+    data_name = file_path.split('/')[-1].split('.')[0]
+    data.to_csv(osp.join(pth, f'{data_name}.tsv'), sep='\t', index=False)
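
Both new helpers are small conveniences: file_size divides st_size by a binary unit (2**20 bytes per MB), and parquet_to_tsv writes a sibling .tsv next to the input. Usage sketch with hypothetical paths:

    size_mb = file_size('/data/bench/split.parquet', unit='MB')
    parquet_to_tsv('/data/bench/split.parquet')  # writes /data/bench/split.tsv
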
@@ -18,8 +18,8 @@ from multiprocessing import Pool, current_process
 from tqdm import tqdm
 import datetime
 import matplotlib.pyplot as plt
 import seaborn as sns
-from tabulate import tabulate_formats, tabulate
+from tabulate import tabulate
 from json import JSONDecoder
 from huggingface_hub import scan_cache_dir
 from sty import fg, bg, ef, rs
@@ -71,7 +71,7 @@ def bincount(lst):
         bins[item] += 1
     return bins


-def get_cache_path(repo_id):
+def get_cache_path(repo_id, branch=None):
     hf_cache_info = scan_cache_dir()
     repos = list(hf_cache_info.repos)
     repo = None

@@ -82,6 +82,8 @@ def get_cache_path(repo_id):
     if repo is None:
         return None
     revs = list(repo.revisions)
+    if branch is not None:
+        revs = [r for r in revs if r.refs == frozenset({branch})]
     rev2keep, last_modified = None, 0
     for rev in revs:
         if rev.last_modified > last_modified:
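
With the new branch argument, the cache lookup can be pinned to revisions whose refs match a single branch; anything not cached still returns None. Sketch (repo id chosen for illustration):

    # Local cache path of the newest revision tagged 'main', or None.
    path = get_cache_path('openbmb/MiniCPM-V-2_6', branch='main')
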
@@ -189,3 +191,26 @@ def version_cmp(v1, v2, op='eq'):
     import operator
     op_func = getattr(operator, op)
     return op_func(version.parse(v1), version.parse(v2))
+
+
+def toliststr(s):
+    if isinstance(s, str) and (s[0] == '[') and (s[-1] == ']'):
+        return [str(x) for x in eval(s)]
+    elif isinstance(s, str):
+        return [s]
+    elif isinstance(s, list):
+        return [str(x) for x in s]
+    raise NotImplementedError
+
+
+def extract_json_objects(text, decoder=JSONDecoder()):
+    pos = 0
+    while True:
+        match = text.find('{', pos)
+        if match == -1: break
+        try:
+            result, index = decoder.raw_decode(text[match:])
+            yield result
+            pos = match + index
+        except ValueError:
+            pos = match + 1
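
toliststr normalizes the three cell shapes the TSVs use (a stringified list, a bare string, or a real list) into a list of strings, and extract_json_objects scans free-form model output for embedded JSON objects. Behavior sketch:

    toliststr("['a.jpg', 'b.jpg']")   # -> ['a.jpg', 'b.jpg']
    toliststr('a.jpg')                # -> ['a.jpg']
    toliststr(['a.jpg', 2])           # -> ['a.jpg', '2']

    list(extract_json_objects('score: {"acc": 0.5} end'))  # -> [{'acc': 0.5}]
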
@@ -7,10 +7,53 @@ from uuid import uuid4
 import os.path as osp
 import base64
 from PIL import Image
 from .file import load, dump
+import sys

 Image.MAX_IMAGE_PIXELS = 1e9


+def rescale_img(img, tgt=None):
+    assert isinstance(tgt, tuple) and -1 in tgt
+    w, h = img.size
+    if tgt[0] != -1:
+        new_w, new_h = tgt[0], int(tgt[0] / w * h)
+    elif tgt[1] != -1:
+        new_w, new_h = int(tgt[1] / h * w), tgt[1]
+    img = img.resize((new_w, new_h))
+    return img
+
+
+def concat_images_vlmeval(images, target_size=-1, mode='h', return_image=False):
+    from .file import md5
+
+    ims = [Image.open(im) for im in images]
+    if target_size != -1:
+        ims = [
+            rescale_img(im, (-1, target_size) if mode == 'h' else (target_size, -1))
+            for im in ims
+        ]
+
+    ws, hs = [x.width for x in ims], [x.height for x in ims]
+    if mode == 'h':
+        new_w, new_h = sum(ws), max(hs)
+        dst = Image.new('RGB', (new_w, new_h))
+        for i, im in enumerate(ims):
+            dst.paste(im, (sum(ws[:i]), 0))
+    elif mode == 'v':
+        new_w, new_h = max(ws), sum(hs)
+        dst = Image.new('RGB', (new_w, new_h))
+        for i, im in enumerate(ims):
+            dst.paste(im, (0, sum(hs[:i])))
+    if return_image:
+        return dst
+    else:
+        _str = '\n'.join(images)
+        str_md5 = md5(_str)
+        tgt = osp.join('/tmp', str_md5 + '.jpg')
+        dst.save(tgt)
+        return tgt
+
+
 def mmqa_display(question, target_size=512):
     question = {k.lower(): v for k, v in question.items()}
     keys = list(question.keys())
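
concat_images_vlmeval tiles images horizontally or vertically; with target_size set, each image is first rescaled by rescale_img (fixed height for mode='h', fixed width for mode='v'). Usage sketch with hypothetical file names:

    # Stitch two pages side by side at a common height of 448 px.
    tile = concat_images_vlmeval(['p1.jpg', 'p2.jpg'], target_size=448,
                                 mode='h', return_image=True)
    # With return_image=False the composite is saved to /tmp/<md5>.jpg,
    # keyed by the md5 of the newline-joined input paths, and the path returned.
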
@@ -41,14 +84,12 @@ def encode_image_to_base64(img, target_size=-1):
     # else, will set the max_size ot (target_size, target_size)
     if img.mode in ('RGBA', 'P'):
         img = img.convert('RGB')
-    tmp = osp.join('/tmp', str(uuid4()) + '.jpg')
     if target_size > 0:
         img.thumbnail((target_size, target_size))
-    img.save(tmp)
-    with open(tmp, 'rb') as image_file:
-        image_data = image_file.read()
+    img_buffer = io.BytesIO()
+    img.save(img_buffer, format='JPEG')
+    image_data = img_buffer.getvalue()
     ret = base64.b64encode(image_data).decode('utf-8')
-    os.remove(tmp)
     return ret
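
The rewrite encodes through an in-memory io.BytesIO buffer instead of saving to a uuid-named file under /tmp, removing disk I/O (and temp-file cleanup) from every encode. Round-trip sketch:

    import base64, io
    from PIL import Image

    img = Image.new('RGB', (32, 32), 'blue')
    b64 = encode_image_to_base64(img, target_size=16)  # thumbnail to 16x16
    restored = Image.open(io.BytesIO(base64.b64decode(b64)))
    assert restored.size == (16, 16)
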
@@ -110,6 +151,7 @@ def circular_pred(df, extract_func=None):
         extract_func = lambda x: x  # noqa: E731
     df = df.sort_values('index')
     from vlmeval.utils import can_infer_option

     shift = int(1e6)

     choices = [extract_func(x) for x in df['prediction']]
@@ -118,9 +160,12 @@ def circular_pred(df, extract_func=None):
     valid_map = {i: True for i in pred_map if i < 1e6}
     for i in df['index']:
         if i >= shift and pred_map[i] and pred_map[i - shift]:
-            if (
-                pred_map[i] not in list(string.ascii_uppercase) or  # noqa: W504
-                pred_map[i - shift] not in list(string.ascii_uppercase)
+            if pred_map[i] not in list(
+                string.ascii_uppercase
+            ) or pred_map[  # noqa: W504
+                i - shift
+            ] not in list(
+                string.ascii_uppercase
             ):
                 valid_map[i % shift] = False
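
For context on the reformatted condition: index i + 1e6 holds the prediction for the circularly shifted copy of question i, and a question is marked invalid unless both predictions are a single uppercase option letter. Shape sketch with hypothetical values:

    shift = int(1e6)
    pred_map = {42: 'A', 42 + shift: 'B'}        # both letters -> stays valid
    pred_map = {42: 'A', 42 + shift: 'Maybe C'}  # not a letter -> valid_map[42] = False
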