Modify eval_mm for MiniCPM-V 2.6

This commit is contained in:
Haoyu Li
2024-08-30 18:18:22 +00:00
parent ab1141ee45
commit 59224808a1
69 changed files with 8231 additions and 1818 deletions

View File

@@ -1,12 +1,7 @@
from .matching_util import can_infer, can_infer_option, can_infer_text
from .mp_util import track_progress_rich
from .custom_prompt import CustomPrompt
from .dataset_config import dataset_URLs, img_root_map, DATASET_TYPE, abbr2full
from .dataset import TSVDataset, split_MMMU, MMMU_result_transfer
__all__ = [
'can_infer', 'can_infer_option', 'can_infer_text', 'track_progress_rich',
'TSVDataset', 'dataset_URLs', 'img_root_map', 'DATASET_TYPE', 'CustomPrompt',
'split_MMMU', 'abbr2full'
]

View File

@@ -1,33 +0,0 @@
from ..smp import *
from .dataset_config import img_root_map
from abc import abstractmethod
class CustomPrompt:
@abstractmethod
def use_custom_prompt(self, dataset):
raise NotImplementedError
@abstractmethod
def build_prompt(self, line, dataset):
raise NotImplementedError
def dump_image(self, line, dataset):
ROOT = LMUDataRoot()
assert isinstance(dataset, str)
img_root = osp.join(ROOT, 'images', img_root_map[dataset] if dataset in img_root_map else dataset)
os.makedirs(img_root, exist_ok=True)
if isinstance(line['image'], list):
tgt_path = []
assert 'image_path' in line
for img, im_name in zip(line['image'], line['image_path']):
path = osp.join(img_root, im_name)
if not read_ok(path):
decode_base64_to_image_file(img, path)
tgt_path.append(path)
else:
tgt_path = osp.join(img_root, f"{line['index']}.jpg")
if not read_ok(tgt_path):
decode_base64_to_image_file(line['image'], tgt_path)
return tgt_path

View File

@@ -1,190 +0,0 @@
import pandas as pd
import hashlib
from ..smp import *
from .dataset_config import dataset_URLs, dataset_md5_dict, DATASET_TYPE
from .custom_prompt import CustomPrompt
from .matching_util import can_infer
def isliststr(s):
return (s[0] == '[') and (s[-1] == ']')
def check_md5(data_path, dataset):
if dataset not in dataset_md5_dict:
warnings.warn(f'We do not have an md5 record for dataset {dataset}, skip the md5 check. ')
return True
assert osp.exists(data_path)
with open(data_path, 'rb') as f:
hash = hashlib.new('md5')
for chunk in iter(lambda: f.read(2**20), b''):
hash.update(chunk)
if str(hash.hexdigest()) == dataset_md5_dict[dataset]:
return True
else:
warnings.warn('this data file is incomplete, so it needs to be downloaded again.')
return False
def split_MMMU(msgs):
text, images = None, []
for s in msgs:
if s['type'] == 'image':
images.append(s['value'])
elif s['type'] == 'text':
assert text is None
text = s['value']
text_segs = text.split('<image ')
segs = [dict(type='text', value=text_segs[0])]
for i, seg in enumerate(text_segs):
if i == 0:
continue
assert istype(seg[0], int) and seg[1] == '>'
image_idx = int(seg[0]) - 1
segs.append(dict(type='image', value=images[image_idx]))
segs.append(dict(type='text', value=seg[2:]))
return segs
def MMMU_result_transfer(result_path):
res = {}
result_data = load(result_path)
mcq = result_data['A'].notna()
lt = len(result_data)
for i in range(lt):
line = result_data.iloc[i]
if mcq[i]:
options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
prediction = line['prediction']
infer_prediction = can_infer(prediction, options)
res[line['id']] = infer_prediction
else:
res[line['id']] = line['prediction']
result_json = result_path.replace('.xlsx', '.json')
dump(res, result_json)
return result_json
class TSVDataset(CustomPrompt):
def __init__(self, dataset='MMBench', skip_noimg=True):
self.data_root = LMUDataRoot()
assert osp.exists(self.data_root)
self.dataset = dataset
self.dataset_type = DATASET_TYPE(dataset)
if dataset in dataset_URLs:
url = dataset_URLs[dataset]
file_name = url.split('/')[-1]
data_path = osp.join(self.data_root, file_name)
if osp.exists(data_path) and check_md5(data_path, dataset):
pass
elif osp.isfile(url):
# If url is actually a file path, use it directly
data_path = url
else:
warnings.warn('The dataset tsv is not downloaded')
download_file(url, data_path)
else:
data_path = osp.join(self.data_root, dataset + '.tsv')
assert osp.exists(data_path)
data = load(data_path)
self.skip_noimg = skip_noimg
if skip_noimg and 'image' in data:
data = data[~pd.isna(data['image'])]
# Prompt for Captioning
if listinstr(['COCO'], dataset):
data['question'] = [(
'Please describe this image in general. Directly provide the description, '
'do not include prefix like "This image depicts". '
)] * len(data)
data['index'] = [str(x) for x in data['index']]
self.meta_only = True
if 'image' in data:
data['image'] = [str(x) for x in data['image']]
image_map = {x: y for x, y in zip(data['index'], data['image'])}
for k in image_map:
if len(image_map[k]) <= 64:
idx = image_map[k]
assert idx in image_map and len(image_map[idx]) > 64
image_map[k] = image_map[idx]
data['image'] = [
eval(image_map[k]) if isliststr(image_map[k]) else image_map[k]
for k in data['index']
]
self.meta_only = False
if 'image_path' in data:
data['image_path'] = [
eval(pths) if isliststr(pths) else pths for pths in data['image_path']
]
if np.all([istype(x, int) for x in data['index']]):
data['index'] = [int(x) for x in data['index']]
self.data = data
def __len__(self):
return len(self.data)
def build_prompt(self, line, dataset=None):
if dataset is None:
dataset = self.dataset
if isinstance(line, int):
line = self.data.iloc[line]
if self.meta_only:
tgt_path = line['image_path']
else:
tgt_path = self.dump_image(line, dataset)
prompt = line['question']
if DATASET_TYPE(dataset) == 'multi-choice':
question = line['question']
options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
options_prompt = 'Options:\n'
for key, item in options.items():
options_prompt += f'{key}. {item}\n'
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
prompt = ''
if hint is not None:
prompt += f'Hint: {hint}\n'
prompt += f'Question: {question}\n'
if len(options):
prompt += options_prompt
prompt += 'Please select the correct answer from the options above. \n'
elif DATASET_TYPE(dataset) == 'VQA':
if listinstr(['ocrvqa', 'textvqa', 'chartqa', 'docvqa'], dataset.lower()):
prompt += '\nPlease try to answer the question with short words or phrases if possible\n.'
msgs = []
if isinstance(tgt_path, list):
msgs.extend([dict(type='image', value=p) for p in tgt_path])
else:
msgs = [dict(type='image', value=tgt_path)]
msgs.append(dict(type='text', value=prompt))
return msgs
def display(self, line):
if isinstance(line, int):
line = self.data.iloc[line]
mmqa_display(line)

View File

@@ -1,158 +0,0 @@
from ..smp import listinstr
dataset_URLs = {
# MMBench v1.0
'MMBench_DEV_EN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN.tsv',
'MMBench_TEST_EN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN.tsv',
'MMBench_DEV_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN.tsv',
'MMBench_TEST_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN.tsv',
'MMBench': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench.tsv', # Internal Only
'MMBench_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN.tsv', # Internal Only
# MMBench v1.1
'MMBench_DEV_EN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN_V11.tsv',
'MMBench_TEST_EN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN_V11.tsv',
'MMBench_DEV_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN_V11.tsv',
'MMBench_TEST_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN_V11.tsv',
'MMBench_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_V11.tsv', # Internal Only
'MMBench_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN_V11.tsv', # Internal Only
# CCBench
'CCBench': 'https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv',
'MME': 'https://opencompass.openxlab.space/utils/VLMEval/MME.tsv',
'SEEDBench_IMG': 'https://opencompass.openxlab.space/utils/VLMEval/SEEDBench_IMG.tsv',
'CORE_MM': 'https://opencompass.openxlab.space/utils/VLMEval/CORE_MM.tsv',
'MMVet': 'https://opencompass.openxlab.space/utils/VLMEval/MMVet.tsv',
'COCO_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/COCO_VAL.tsv',
'OCRVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TEST.tsv',
'OCRVQA_TESTCORE': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TESTCORE.tsv',
'TextVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/TextVQA_VAL.tsv',
'MMMU_DEV_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv',
'MMMU_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv',
'MathVista_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv',
'ScienceQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_VAL.tsv',
'ScienceQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_TEST.tsv',
'HallusionBench': 'https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv',
'DocVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/DocVQA_VAL.tsv',
'DocVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/DocVQA_TEST.tsv',
'InfoVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/InfoVQA_VAL.tsv',
'InfoVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/InfoVQA_TEST.tsv',
'AI2D_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv',
'LLaVABench': 'https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv',
'OCRBench': 'https://opencompass.openxlab.space/utils/VLMEval/OCRBench.tsv',
'ChartQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/ChartQA_TEST.tsv',
'MMStar': 'https://opencompass.openxlab.space/utils/VLMEval/MMStar.tsv',
'RealWorldQA': 'https://opencompass.openxlab.space/utils/VLMEval/RealWorldQA.tsv',
'POPE': 'https://opencompass.openxlab.space/utils/VLMEval/POPE.tsv',
}
dataset_md5_dict = {
# MMBench v1.0
'MMBench_DEV_EN': 'b6caf1133a01c6bb705cf753bb527ed8',
'MMBench_TEST_EN': '6939fadb0ce626fefc0bdc9c64efc528',
'MMBench_DEV_CN': '08b8fc3324a5ed74155350f57be69fbd',
'MMBench_TEST_CN': '7e1239baf0ee4c8b513e19705a0f317e',
'MMBench': '4115aea3383f3dd0083be6a633e0f820', # Internal Only
'MMBench_CN': '2e053ffc90ea598b1feae13c36dc13ee', # Internal Only
# MMBench v1.1
'MMBench_DEV_EN_V11': '30c05be8f2f347a50be25aa067248184',
'MMBench_TEST_EN_V11': '26f0f15381a21720255091d3e0316ce6',
'MMBench_DEV_CN_V11': '593f9b5f6bea453d870a798b34ae4f37',
'MMBench_TEST_CN_V11': '74bbe4556dac745613c7cbe5ad787050',
'MMBench_V11': 'b9276414f57af1308dcc4d0cd9b42e7c', # Internal Only
'MMBench_CN_V11': '95f6980dd1b4de38e3cbffe0305a3f25', # Internal Only
# CCBench
'CCBench': '1de88b4257e7eee3f60b18d45eda6f07',
'MME': 'b36b43c3f09801f5d368627fb92187c3',
'SEEDBench_IMG': '68017231464752261a2526d6ca3a10c0',
'CORE_MM': '8a8da2f2232e79caf98415bfdf0a202d',
'MMVet': '748aa6d4aa9d4de798306a63718455e3',
'COCO_VAL': '72a5079dead060269ac222c5aa5128af',
'OCRVQA_TEST': 'ca46a6d74b403e9d6c0b670f6fc00db9',
'OCRVQA_TESTCORE': 'c5239fe77db8bdc1f2ad8e55e0d1fe97',
'TextVQA_VAL': 'b233b31f551bbf4056f2f955da3a92cd',
'MMMU_DEV_VAL': '521afc0f3bf341e6654327792781644d',
'MMMU_TEST': 'c19875d11a2d348d07e5eb4bdf33166d',
'MathVista_MINI': 'f199b98e178e5a2a20e7048f5dcb0464',
'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3',
'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f',
'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c',
'DocVQA_VAL': 'd5ee77e1926ff10690d469c56b73eabf',
'DocVQA_TEST': '6a2f28cac26ef2d3447374e8c6f6c8e9',
'InfoVQA_VAL': '2342e9c225222f0ef4dec545ebb126fe',
'InfoVQA_TEST': 'df535bf51b88dc9718252c34131a6227',
'AI2D_TEST': '0f593e0d1c7df9a3d69bf1f947e71975',
'LLaVABench': 'd382a093f749a697820d3dadd61c8428',
'OCRBench': 'e953d98a987cc6e26ef717b61260b778',
'ChartQA_TEST': 'c902e0aa9be5582a7aad6dcf52734b42',
'MMStar': 'e1ecd2140806c1b1bbf54b43372efb9e',
'RealWorldQA': '92321028d2bc29040284b6674721e48f',
'POPE': 'c12f5acb142f2ef1f85a26ba2fbe41d5',
}
img_root_map = {k: k for k in dataset_URLs}
img_root_map.update({
# MMBench v1.0
'MMBench_DEV_EN': 'MMBench',
'MMBench_TEST_EN': 'MMBench',
'MMBench_DEV_CN': 'MMBench',
'MMBench_TEST_CN': 'MMBench',
'MMBench': 'MMBench', # Internal Only
'MMBench_CN': 'MMBench', # Internal Only
# MMBench v1.1
'MMBench_DEV_EN_V11': 'MMBench_V11',
'MMBench_TEST_EN_V11': 'MMBench_V11',
'MMBench_DEV_CN_V11': 'MMBench_V11',
'MMBench_TEST_CN_V11': 'MMBench_V11',
'MMBench_V11': 'MMBench_V11', # Internal Only
'MMBench_CN_V11': 'MMBench_V11', # Internal Only
'COCO_VAL': 'COCO',
'OCRVQA_TEST': 'OCRVQA',
'OCRVQA_TESTCORE': 'OCRVQA',
'TextVQA_VAL': 'TextVQA',
'MMMU_DEV_VAL': 'MMMU',
'MMMU_TEST': 'MMMU',
'MathVista_MINI': 'MathVista',
'HallusionBench': 'Hallusion',
'DocVQA_VAL': 'DocVQA',
'DocVQA_TEST': 'DocVQA_TEST',
'OCRBench': 'OCRBench',
'ChartQA_TEST': 'ChartQA_TEST',
'InfoVQA_VAL': 'InfoVQA_VAL',
'InfoVQA_TEST': 'InfoVQA_TEST',
'MMStar': 'MMStar',
'RealWorldQA': 'RealWorldQA',
'POPE': 'POPE',
})
assert set(dataset_URLs) == set(img_root_map)
def DATASET_TYPE(dataset):
# Dealing with Custom Dataset
dataset = dataset.lower()
if listinstr(['mmbench', 'seedbench', 'ccbench', 'mmmu', 'scienceqa', 'ai2d', 'mmstar', 'realworldqa'], dataset):
return 'multi-choice'
elif listinstr(['mme', 'hallusion', 'pope'], dataset):
return 'Y/N'
elif 'coco' in dataset:
return 'Caption'
elif listinstr(['ocrvqa', 'textvqa', 'chartqa', 'mathvista', 'docvqa', 'infovqa', 'llavabench',
'mmvet', 'ocrbench'], dataset):
return 'VQA'
else:
if dataset not in dataset_URLs:
import warnings
warnings.warn(f"Dataset {dataset} not found in dataset_URLs, will use 'multi-choice' as the default TYPE.")
return 'multi-choice'
else:
return 'QA'
def abbr2full(s):
datasets = [x for x in img_root_map]
ins = [s in d for d in datasets]
if sum(ins) == 1:
for d in datasets:
if s in d:
return d
else:
return s

View File

@@ -145,7 +145,7 @@ def track_progress_rich(func: Callable,
results = []
for task in tasks:
result, idx = worker(task)
results.append(worker(task)[0])
results.append(result)
if save is not None:
with portalocker.Lock(save, timeout=5) as fh:
ans = load(save)

View File

@@ -0,0 +1,97 @@
from ..smp import *
from ..dataset.utils.judge_util import build_judge
from ..dataset.utils.multiple_choice import extract_answer_from_item
from .matching_util import can_infer
from .mp_util import track_progress_rich
def MMMU_result_transfer(result_path):
res = {}
result_data = load(result_path)
mcq = result_data['A'].notna()
lt = len(result_data)
for i in range(lt):
line = result_data.iloc[i]
if mcq[i]:
options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
prediction = line['prediction']
infer_prediction = can_infer(prediction, options)
res[line['id']] = infer_prediction
else:
res[line['id']] = line['prediction']
result_json = result_path.replace('.xlsx', '.json')
dump(res, result_json)
return result_json
def MMTBench_result_transfer(eval_file, dataset='default', **judge_kwargs):
logger = get_logger('Evaluation')
nproc = judge_kwargs.pop('nproc', 4)
rd.seed(2680)
suffix = eval_file.split('.')[-1]
model = judge_kwargs['model']
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
name_str_map = {
'chatgpt-0125': 'openai',
'gpt-4-0125': 'gpt4'
}
name_str = name_str_map[model] if model in name_str_map else model
if model == 'exact_matching':
model = None
elif gpt_key_set():
model = build_judge(**judge_kwargs)
if not model.working():
logger.error('The OPENAI API is not working properly, will use exact matching for evaluation')
model = None
else:
logger.error('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
model = None
logger.info(f'Evaluating {eval_file}')
result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_option.pkl')
result = {}
if osp.exists(result_file):
result = load(result_file)
data = load(eval_file)
assert 'index' in data, 'Essentail columns missing in the eval_file.'
data = data.sort_values(by='index')
data['prediction'] = [str(x) for x in data['prediction']]
for k in data.keys():
data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
idx2lines = {data.iloc[i]['index']: data.iloc[i] for i in range(len(data))}
idx2lines = {k: v for k, v in idx2lines.items() if k not in result}
indices = list(idx2lines.keys())
lines = [idx2lines[i] for i in indices]
tups = [(model, line) for line in lines]
res = track_progress_rich(
extract_answer_from_item,
tups,
nproc=nproc,
chunksize=nproc,
save=result_file,
keys=indices)
for i, r in zip(indices, res):
if i in result:
assert result[i]['opt'] == r['opt'] and result[i]['log'] == r['log']
else:
result[i] = r
indices = list(data['index'])
data['opt'] = [result[i]['opt'] for i in data['index']]
data['log'] = [result[i]['log'] for i in data['index']]
# load split
output_path = eval_file.replace(f'.{suffix}', f'_{name_str}_submission.tsv')
dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_submission.tsv'))
return output_path