Add eval_mm dir

This commit is contained in:
trainfanlhy
2024-05-28 01:21:34 +08:00
parent 7e12387362
commit 65f5567a3a
49 changed files with 5610 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
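# Package init: re-export the file, vlm, misc and log helpers under a single namespace.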
from .file import *
from .vlm import *
from .misc import *
from .log import *

View File

@@ -0,0 +1,230 @@
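# File I/O helpers: suffix-dispatched load/dump, dataset-root resolution, hashing,
# downloads, and MIME-aware parsing of local files and URLs.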
import json
import pickle
import pandas as pd
import os
import csv
import hashlib
import os.path as osp
import time
import numpy as np
import validators
import mimetypes
def LMUDataRoot():
if 'LMUData' in os.environ and osp.exists(os.environ['LMUData']):
return os.environ['LMUData']
home = osp.expanduser('~')
root = osp.join(home, 'LMUData')
os.makedirs(root, exist_ok=True)
# root = './LMUData'
# os.makedirs(root, exist_ok=True)
return root
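# MMBenchOfficialServer returns True only when a local answer TSV for the dataset exists
# under LMUDataRoot() with a fully populated 'answer' column, i.e. the split can presumably
# be scored locally rather than via the official evaluation server.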
def MMBenchOfficialServer(dataset_name):
root = LMUDataRoot()
if dataset_name in ['MMBench', 'MMBench_V11', 'MMBench_CN', 'MMBench_CN_V11']:
ans_file = f'{root}/{dataset_name}.tsv'
if osp.exists(ans_file):
data = load(ans_file)
if 'answer' in data and sum([pd.isna(x) for x in data['answer']]) == 0:
return True
if dataset_name in ['MMBench_TEST_EN', 'MMBench_TEST_CN', 'MMBench_TEST_EN_V11', 'MMBench_TEST_CN_V11']:
ans_file1 = f'{root}/{dataset_name}.tsv'
mapp = {
'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_CN': 'MMBench_CN',
'MMBench_TEST_EN_V11': 'MMBench_V11', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11',
}
ans_file2 = f'{root}/{mapp[dataset_name]}.tsv'
for f in [ans_file1, ans_file2]:
if osp.exists(f):
data = load(f)
if 'answer' in data and sum([pd.isna(x) for x in data['answer']]) == 0:
return True
return False
class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
np.int16, np.int32, np.int64, np.uint8,
np.uint16, np.uint32, np.uint64)):
return int(obj)
elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
return float(obj)
elif isinstance(obj, (np.complex_, np.complex64, np.complex128)):
return {'real': obj.real, 'imag': obj.imag}
elif isinstance(obj, (np.ndarray,)):
return obj.tolist()
        elif isinstance(obj, np.bool_):
            return bool(obj)
        elif isinstance(obj, np.void):
            return None
return json.JSONEncoder.default(self, obj)
# LOAD & DUMP
def dump(data, f, **kwargs):
def dump_pkl(data, pth, **kwargs):
pickle.dump(data, open(pth, 'wb'))
    def dump_json(data, pth, **kwargs):
        with open(pth, 'w', encoding='utf-8') as fout:
            json.dump(data, fout, indent=4, ensure_ascii=False, cls=NumpyEncoder)
def dump_jsonl(data, f, **kwargs):
lines = [json.dumps(x, ensure_ascii=False, cls=NumpyEncoder) for x in data]
with open(f, 'w', encoding='utf8') as fout:
fout.write('\n'.join(lines))
def dump_xlsx(data, f, **kwargs):
data.to_excel(f, index=False, engine='xlsxwriter')
def dump_csv(data, f, quoting=csv.QUOTE_ALL):
data.to_csv(f, index=False, encoding='utf-8', quoting=quoting)
def dump_tsv(data, f, quoting=csv.QUOTE_ALL):
data.to_csv(f, sep='\t', index=False, encoding='utf-8', quoting=quoting)
handlers = dict(pkl=dump_pkl, json=dump_json, jsonl=dump_jsonl, xlsx=dump_xlsx, csv=dump_csv, tsv=dump_tsv)
suffix = f.split('.')[-1]
return handlers[suffix](data, f, **kwargs)
def load(f):
def load_pkl(pth):
return pickle.load(open(pth, 'rb'))
def load_json(pth):
return json.load(open(pth, 'r', encoding='utf-8'))
def load_jsonl(f):
lines = open(f, encoding='utf-8').readlines()
lines = [x.strip() for x in lines]
if lines[-1] == '':
lines = lines[:-1]
data = [json.loads(x) for x in lines]
return data
def load_xlsx(f):
return pd.read_excel(f)
def load_csv(f):
return pd.read_csv(f)
def load_tsv(f):
return pd.read_csv(f, sep='\t')
handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv)
suffix = f.split('.')[-1]
return handlers[suffix](f)
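# Both dump() and load() dispatch on the file suffix (pkl/json/jsonl/xlsx/csv/tsv).
# Example with hypothetical paths: dump(df, 'scores.tsv'); records = load('preds.jsonl').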
def download_file(url, filename=None):
import urllib.request
from tqdm import tqdm
class DownloadProgressBar(tqdm):
def update_to(self, b=1, bsize=1, tsize=None):
if tsize is not None:
self.total = tsize
self.update(b * bsize - self.n)
if filename is None:
filename = url.split('/')[-1]
with DownloadProgressBar(unit='B', unit_scale=True,
miniters=1, desc=url.split('/')[-1]) as t:
urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)
return filename
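# ls() lists entries under dirname, filtered by substring patterns (a leading '!' excludes)
# and by mode ('all' / 'dir' / 'file'); level may be an int depth, or a string such as '2+'
# to collect files from every depth up to that number.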
def ls(dirname='.', match=[], mode='all', level=1):
if isinstance(level, str):
assert '+' in level
level = int(level[:-1])
res = []
for i in range(1, level + 1):
res.extend(ls(dirname, match=match, mode='file', level=i))
return res
if dirname == '.':
ans = os.listdir(dirname)
else:
ans = [osp.join(dirname, x) for x in os.listdir(dirname)]
assert mode in ['all', 'dir', 'file']
assert level >= 1 and isinstance(level, int)
if level == 1:
if isinstance(match, str):
match = [match]
for m in match:
if len(m) == 0:
continue
if m[0] != '!':
ans = [x for x in ans if m in x]
else:
ans = [x for x in ans if m[1:] not in x]
if mode == 'dir':
ans = [x for x in ans if osp.isdir(x)]
elif mode == 'file':
ans = [x for x in ans if not osp.isdir(x)]
return ans
else:
dirs = [x for x in ans if osp.isdir(x)]
res = []
for d in dirs:
res.extend(ls(d, match=match, mode=mode, level=level - 1))
return res
def mrlines(fname, sp='\n'):
f = open(fname).read().split(sp)
while f != [] and f[-1] == '':
f = f[:-1]
return f
def mwlines(lines, fname):
with open(fname, 'w') as fout:
fout.write('\n'.join(lines))
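# md5() hashes the file contents (in 1 MiB chunks) when s is an existing path,
# otherwise it hashes the UTF-8 encoding of the string itself.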
def md5(s):
hash = hashlib.new('md5')
if osp.exists(s):
with open(s, 'rb') as f:
for chunk in iter(lambda: f.read(2**20), b''):
hash.update(chunk)
else:
hash.update(s.encode('utf-8'))
return str(hash.hexdigest())
def last_modified(pth):
stamp = osp.getmtime(pth)
m_ti = time.ctime(stamp)
t_obj = time.strptime(m_ti)
t = time.strftime('%Y%m%d%H%M%S', t_obj)[2:]
return t
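# parse_file() returns (mime_type, local_path) for existing files and for downloadable URLs
# with a recognised extension (downloads are cached under LMUDataRoot()/files, named by the
# md5 of the URL string); other URLs yield ('url', s) and anything else (None, s).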
def parse_file(s):
if osp.exists(s):
assert osp.isfile(s)
suffix = osp.splitext(s)[1].lower()
mime = mimetypes.types_map.get(suffix, 'unknown')
return (mime, s)
elif validators.url(s):
suffix = osp.splitext(s)[1].lower()
if suffix in mimetypes.types_map:
mime = mimetypes.types_map[suffix]
dname = osp.join(LMUDataRoot(), 'files')
os.makedirs(dname, exist_ok=True)
tgt = osp.join(dname, md5(s) + suffix)
download_file(s, tgt)
return (mime, tgt)
else:
return ('url', s)
else:
return (None, s)

View File

@@ -0,0 +1,44 @@
import logging
logger_initialized = {}
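# get_logger() caches loggers by name (including parent-name prefixes). When
# torch.distributed is initialised, only rank 0 attaches the optional file handler and
# logs at log_level; other ranks only emit ERROR-level messages to avoid duplicated output.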
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
stream_handler = logging.StreamHandler()
handlers = [stream_handler]
try:
import torch.distributed as dist
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0
except ImportError:
rank = 0
if rank == 0 and log_file is not None:
file_handler = logging.FileHandler(log_file, file_mode)
handlers.append(file_handler)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
for handler in handlers:
handler.setFormatter(formatter)
handler.setLevel(log_level)
logger.addHandler(handler)
if rank == 0:
logger.setLevel(log_level)
else:
logger.setLevel(logging.ERROR)
logger_initialized[name] = True
return logger

View File

@@ -0,0 +1,191 @@
# flake8: noqa: F401, F403
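# Miscellaneous helpers shared across the evaluation code: string and colour utilities,
# HuggingFace cache lookup, distributed rank helpers, env loading and small shell wrappers.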
import abc
import argparse
import csv
import multiprocessing as mp
import os
import os.path as osp
import copy as cp
import random as rd
import requests
import shutil
import subprocess
import warnings
import logging
import pandas as pd
from collections import OrderedDict, defaultdict
from multiprocessing import Pool, current_process
from tqdm import tqdm
import datetime
import matplotlib.pyplot as plt
import seaborn as sns
from tabulate import tabulate_formats, tabulate
from huggingface_hub import scan_cache_dir
from sty import fg, bg, ef, rs
def process_punctuation(inText):
import re
outText = inText
punct = [
';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-',
'>', '<', '@', '`', ',', '?', '!'
]
commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605
periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605
for p in punct:
if (p + ' ' in inText or ' ' + p in inText) or (re.search(
commaStrip, inText) is not None):
outText = outText.replace(p, '')
else:
outText = outText.replace(p, ' ')
outText = periodStrip.sub('', outText, re.UNICODE)
return outText
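# Note: process_punctuation appears to follow the reference VQA evaluation script verbatim
# (including passing re.UNICODE as the positional `count` argument of periodStrip.sub);
# it is left untouched so that scores stay comparable with that script.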
def h2r(value):
if value[0] == '#':
value = value[1:]
assert len(value) == 6
return tuple(int(value[i:i + 2], 16) for i in range(0, 6, 2))
def r2h(rgb):
return '#%02x%02x%02x' % rgb
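# h2r / r2h convert between hex colour strings and RGB tuples,
# e.g. h2r('#ff8800') -> (255, 136, 0) and r2h((255, 136, 0)) -> '#ff8800'.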
def colored(s, color):
if isinstance(color, str):
if hasattr(fg, color):
return getattr(fg, color) + s + fg.rs
color = h2r(color)
return fg(*color) + s + fg.rs
def istype(s, type):
if isinstance(s, type):
return True
try:
return isinstance(eval(s), type)
except Exception as _:
return False
def bincount(lst):
bins = defaultdict(lambda: 0)
for item in lst:
bins[item] += 1
return bins
def get_cache_path(repo_id):
hf_cache_info = scan_cache_dir()
repos = list(hf_cache_info.repos)
repo = None
for r in repos:
if r.repo_id == repo_id:
repo = r
break
if repo is None:
return None
revs = list(repo.revisions)
rev2keep, last_modified = None, 0
for rev in revs:
if rev.last_modified > last_modified:
rev2keep, last_modified = rev, rev.last_modified
if rev2keep is None:
return None
return str(rev2keep.snapshot_path)
def proxy_set(s):
import os
for key in ['http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY']:
os.environ[key] = s
def get_rank_and_world_size():
rank = int(os.environ.get('RANK', 0))
world_size = int(os.environ.get('WORLD_SIZE', 1))
return rank, world_size
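# RANK and WORLD_SIZE are the environment variables set by torchrun / torch.distributed
# launchers; single-process runs fall back to rank 0 of world size 1.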
def splitlen(s, sym='/'):
return len(s.split(sym))
def listinstr(lst, s):
assert isinstance(lst, list)
for item in lst:
if item in s:
return True
return False
def d2df(D):
return pd.DataFrame({x: [D[x]] for x in D})
def cn_string(s):
import re
if re.search(u'[\u4e00-\u9fff]', s):
return True
return False
try:
import decord
except ImportError:
pass
def timestr(second=True, minute=False):
s = datetime.datetime.now().strftime('%Y%m%d%H%M%S')[2:]
if second:
return s
elif minute:
return s[:-2]
else:
return s[:-4]
def dict_merge(dct, merge_dct):
for k, _ in merge_dct.items():
if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)): #noqa
dict_merge(dct[k], merge_dct[k])
else:
dct[k] = merge_dct[k]
def youtube_dl(idx):
cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4'
os.system(cmd)
def run_command(cmd):
if isinstance(cmd, str):
cmd = cmd.split()
return subprocess.check_output(cmd).decode()
def load_env():
logger = logging.getLogger('LOAD_ENV')
try:
import vlmeval
except ImportError:
logger.error('VLMEval is not installed. Failed to import environment variables from .env file. ')
return
pth = osp.realpath(vlmeval.__path__[0])
pth = osp.join(pth, '../.env')
pth = osp.realpath(pth)
if not osp.exists(pth):
logger.error(f'Did not detect the .env file at {pth}, failed to load. ')
return
from dotenv import dotenv_values
values = dotenv_values(pth)
for k, v in values.items():
if v is not None and len(v):
os.environ[k] = v
logger.info(f'API Keys successfully loaded from {pth}')
def pip_install_robust(package):
    import sys
    retry = 3
    while retry > 0:
        try:
            # Strip any version specifier (e.g. 'package==1.0') before importing.
            package_base = package.split('=')[0]
            __import__(package_base)
            return True
        except ImportError:
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
            retry -= 1
    return False
def version_cmp(v1, v2, op='eq'):
from packaging import version
import operator
op_func = getattr(operator, op)
return op_func(version.parse(v1), version.parse(v2))
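# Example: version_cmp('4.37.0', '4.33.0', 'ge') -> True, using packaging's PEP 440 ordering.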

View File

@@ -0,0 +1,134 @@
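# Image helpers for VLM evaluation: base64 (de)serialisation, notebook display of samples,
# option formatting and the circular-consistency prediction check.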
import os
import io
import pandas as pd
import numpy as np
import string
from uuid import uuid4
import os.path as osp
import base64
from PIL import Image
from .file import load, dump
Image.MAX_IMAGE_PIXELS = 1e9
def mmqa_display(question, target_size=512):
question = {k.lower(): v for k, v in question.items()}
keys = list(question.keys())
keys = [k for k in keys if k not in ['index', 'image']]
images = question['image']
if isinstance(images, str):
images = [images]
idx = question.pop('index', 'XXX')
print(f'INDEX: {idx}')
for im in images:
image = decode_base64_to_image(im, target_size=target_size)
display(image) # noqa: F821
for k in keys:
try:
if not pd.isna(question[k]):
print(f'{k.upper()}. {question[k]}')
except ValueError:
if False in pd.isna(question[k]):
print(f'{k.upper()}. {question[k]}')
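# mmqa_display is intended for interactive notebooks: `display` is the IPython display
# function, which is only available as a builtin in such sessions (hence the noqa: F821).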
def encode_image_to_base64(img, target_size=-1):
    # if target_size == -1, will not do resizing
    # else, will set the max size to (target_size, target_size)
if img.mode in ('RGBA', 'P'):
img = img.convert('RGB')
tmp = osp.join('/tmp', str(uuid4()) + '.jpg')
if target_size > 0:
img.thumbnail((target_size, target_size))
img.save(tmp)
with open(tmp, 'rb') as image_file:
image_data = image_file.read()
ret = base64.b64encode(image_data).decode('utf-8')
os.remove(tmp)
return ret
def encode_image_file_to_base64(image_path, target_size=-1):
image = Image.open(image_path)
return encode_image_to_base64(image, target_size=target_size)
def decode_base64_to_image(base64_string, target_size=-1):
image_data = base64.b64decode(base64_string)
image = Image.open(io.BytesIO(image_data))
if image.mode in ('RGBA', 'P'):
image = image.convert('RGB')
if target_size > 0:
image.thumbnail((target_size, target_size))
return image
def decode_base64_to_image_file(base64_string, image_path, target_size=-1):
image = decode_base64_to_image(base64_string, target_size=target_size)
image.save(image_path)
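# Base64 helpers: encode_image_to_base64 converts RGBA/P images to RGB, optionally thumbnails
# them to target_size, and re-encodes them as JPEG via a temporary file under /tmp before
# base64-encoding the bytes, so the encode/decode round trip is lossy.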
def build_option_str(option_dict):
s = 'There are several options: \n'
for c, content in option_dict.items():
if not pd.isna(content):
s += f'{c}. {content}\n'
return s
def isimg(s):
return osp.exists(s) or s.startswith('http')
def read_ok(img_path):
if not osp.exists(img_path):
return False
    try:
        im = Image.open(img_path)
        assert im.size[0] > 0 and im.size[1] > 0
        return True
    except Exception:
        return False
def gpt_key_set():
openai_key = os.environ.get('OPENAI_API_KEY', None)
return isinstance(openai_key, str) and openai_key.startswith('sk-')
def apiok(wrapper):
s = wrapper.generate('Hello!')
return wrapper.fail_msg not in s
def circular_pred(df, extract_func=None):
if extract_func is None:
extract_func = lambda x: x # noqa: E731
df = df.sort_values('index')
from vlmeval.utils import can_infer_option
shift = int(1e6)
choices = [extract_func(x) for x in df['prediction']]
pred_map = {i: c for i, c in zip(df['index'], choices)}
flag_map = {i: True for i in pred_map if i < 1e6}
valid_map = {i: True for i in pred_map if i < 1e6}
for i in df['index']:
if i >= shift and pred_map[i] and pred_map[i - shift]:
if (
pred_map[i] not in list(string.ascii_uppercase) or # noqa: W504
pred_map[i - shift] not in list(string.ascii_uppercase)
):
valid_map[i % shift] = False
continue
if (ord(pred_map[i]) - ord(pred_map[i - shift])) % 4 == 1:
continue
else:
flag_map[i % shift] = False
flag_map = {k: v for k, v in flag_map.items() if valid_map[k]}
flags = list(flag_map.values())
return np.mean(flags)
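# circular_pred computes the circular-consistency rate: each base question (index < 1e6) is
# compared with its shifted copies (index + k * 1e6). A group stays consistent only if every
# copy predicts a capital-letter option exactly one position (mod 4) after the previous copy's
# prediction; groups containing non-letter predictions are dropped, and the mean over the
# remaining groups is returned.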