Modify eval_mm for MiniCPM-o 2.6

This commit is contained in:
Poppy Xu
2025-01-21 15:34:54 +08:00
parent ec68cefc17
commit d8f382e157
82 changed files with 14279 additions and 843 deletions

View File

@@ -3,7 +3,7 @@ import warnings
from .image_base import ImageBaseDataset
from .utils import build_judge, DEBUG_MESSAGE
from ..smp import *
import pandas as pd
MMMB_URLS = {
'MMMB_ar': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_ar.tsv',
@@ -42,31 +42,31 @@ class ImageMCQDataset(ImageBaseDataset):
DATASET_URL = {
# MMBench v1.0
'MMBench_DEV_EN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN.tsv',
'MMBench_TEST_EN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN.tsv',
'MMBench_DEV_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN.tsv',
'MMBench_TEST_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN.tsv',
'MMBench': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench.tsv', # Internal Only
'MMBench_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN.tsv', # Internal Only
'MMBench_DEV_EN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_EN.tsv',
'MMBench_TEST_EN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_EN.tsv',
'MMBench_DEV_CN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_CN.tsv',
'MMBench_TEST_CN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_CN.tsv',
'MMBench': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench.tsv', # Internal
'MMBench_CN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_CN.tsv', # Internal
# MMBench v1.1
'MMBench_DEV_EN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN_V11.tsv',
'MMBench_TEST_EN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN_V11.tsv',
'MMBench_DEV_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN_V11.tsv',
'MMBench_TEST_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN_V11.tsv',
'MMBench_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_V11.tsv', # Internal Only
'MMBench_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN_V11.tsv', # Internal Only
'MMBench_DEV_EN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_EN_V11.tsv',
'MMBench_TEST_EN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_EN_V11.tsv',
'MMBench_DEV_CN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_CN_V11.tsv',
'MMBench_TEST_CN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_CN_V11.tsv',
'MMBench_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_V11.tsv', # Internal
'MMBench_CN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_CN_V11.tsv', # Internal
# SEEDBench Series
'SEEDBench_IMG': 'https://opencompass.openxlab.space/utils/VLMEval/SEEDBench_IMG.tsv',
'SEEDBench_IMG': 'https://opencompass.openxlab.space/utils/benchmarks/SEEDBench/SEEDBench_IMG.tsv',
'SEEDBench2': 'https://huggingface.co/datasets/VLMEval/SEEDBench2/resolve/main/SEEDBench2.tsv',
'SEEDBench2_Plus': 'https://opencompass.openxlab.space/utils/VLMEval/SEEDBench2_Plus.tsv',
'SEEDBench2_Plus': 'https://opencompass.openxlab.space/utils/benchmarks/SEEDBench/SEEDBench2_Plus.tsv',
# ScienceQA Series
'ScienceQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_VAL.tsv',
'ScienceQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_TEST.tsv',
'ScienceQA_VAL': 'https://opencompass.openxlab.space/utils/benchmarks/ScienceQA/ScienceQA_VAL.tsv',
'ScienceQA_TEST': 'https://opencompass.openxlab.space/utils/benchmarks/ScienceQA/ScienceQA_TEST.tsv',
# MMT-Bench
'MMT-Bench_ALL_MI': 'https://opencompass.openxlab.space/utils/VLMEval/MMT-Bench_ALL_MI.tsv',
'MMT-Bench_ALL': 'https://opencompass.openxlab.space/utils/VLMEval/MMT-Bench_ALL.tsv',
'MMT-Bench_VAL_MI': 'https://opencompass.openxlab.space/utils/VLMEval/MMT-Bench_VAL_MI.tsv',
'MMT-Bench_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/MMT-Bench_VAL.tsv',
'MMT-Bench_ALL_MI': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_ALL_MI.tsv',
'MMT-Bench_ALL': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_ALL.tsv',
'MMT-Bench_VAL_MI': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_VAL_MI.tsv',
'MMT-Bench_VAL': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_VAL.tsv',
# AesBench
'AesBench_VAL': 'https://huggingface.co/datasets/VLMEval/AesBench/resolve/main/AesBench_VAL.tsv',
'AesBench_TEST': 'https://huggingface.co/datasets/VLMEval/AesBench/resolve/main/AesBench_TEST.tsv',
@@ -76,6 +76,9 @@ class ImageMCQDataset(ImageBaseDataset):
# A-Bench
'A-Bench_VAL': 'https://huggingface.co/datasets/zhangzicheng/abench_tsv/resolve/main/A-bench_VAL.tsv',
'A-Bench_TEST': 'https://huggingface.co/datasets/zhangzicheng/abench_tsv/resolve/main/A-bench_TEST.tsv',
# R-Bench
'R-Bench-Dis': 'https://huggingface.co/datasets/lcysyzxdxc/R-Bench/blob/main/R-bench-dis.tsv',
'R-Bench-Ref': 'https://huggingface.co/datasets/lcysyzxdxc/R-Bench/blob/main/R-bench-ref.tsv',
# Other Benchmarks
'CCBench': 'https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv',
'AI2D_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv',
@@ -88,7 +91,16 @@ class ImageMCQDataset(ImageBaseDataset):
'https://huggingface.co/datasets/weikaih/TaskMeAnything-v1-imageqa-random/'
'resolve/main/TaskMeAnything-v1-imageqa-random.tsv'
),
'A-OKVQA': 'https://huggingface.co/datasets/Allen8/A-OKVQA/resolve/main/a-okvqa.tsv'
'A-OKVQA': 'https://huggingface.co/datasets/Allen8/A-OKVQA/resolve/main/a-okvqa.tsv',
'WorldMedQA-V': 'https://opencompass.openxlab.space/utils/VLMEval/WorldMedQA-V.tsv',
'VisOnlyQA-VLMEvalKit': (
'https://huggingface.co/datasets/ryokamoi/VisOnlyQA_Eval_Real/'
'resolve/main/visonlyqa_vlmevalkit.tsv'
),
'3DSRBench': (
'https://huggingface.co/datasets/ccvl/3DSRBench/'
'resolve/main/3dsrbench_v1_vlmevalkit_circular.tsv'
),
}
DATASET_MD5 = {
@@ -127,15 +139,21 @@ class ImageMCQDataset(ImageBaseDataset):
# A-Bench
'A-Bench_VAL': '218563ec50d34bb336c814143a5bb9c1',
'A-Bench_TEST': '567013fb033a20cf23f51d8e865bd16c',
# R-Bench
'R-Bench-Dis': 'd6e961dbfc43350688af2560226830b4',
'R-Bench-Ref': '270c1cb555acb523f3fdb178ed57021d',
# Other Benchmarks
'CCBench': 'f5dde47f24dc5a6fb6e595b409b466ac',
'AI2D_TEST': '0f593e0d1c7df9a3d69bf1f947e71975',
'AI2D_TEST_NO_MASK': 'fd8f463634d4fe9fbd23b876e8eea5be',
'MMStar': 'e1ecd2140806c1b1bbf54b43372efb9e',
'RealWorldQA': '92321028d2bc29040284b6674721e48f',
'RealWorldQA': '4de008f55dc4fd008ca9e15321dc44b7',
'MLLMGuard_DS': '975fc0dd7119386e198c37d71e274b3f',
'BLINK': '3b6649b6a662184ea046908e5506260e',
'TaskMeAnything_v1_imageqa_random': '023fef69e2ca21827afb77c5ec3bc889'
'TaskMeAnything_v1_imageqa_random': '023fef69e2ca21827afb77c5ec3bc889',
'WorldMedQA-V': '441e63875e30c87f5750528b57b41285',
"VisOnlyQA-VLMEvalKit": 'cf460a31d2acb8d3a7cecd0e69298bfa',
'3DSRBench': '13a99f33164dc1b9faf0e8b8b01fd6f2',
}
DATASET_URL.update(MMMB_URLS)
@@ -256,6 +274,12 @@ class ImageMCQDataset(ImageBaseDataset):
warnings.warn('Note that AesBench VAL is just a toy version of AesBench TEST. For full results, \
please evaluate on AesBench TEST. The AesBench TEST dataset is more than 20 times \
larger than the VAL dataset and the leaderboard results are based on AesBench TEST.')
if dataset == 'VisOnlyQA-VLMEvalKit':
warnings.warn('Note that the results on VisOnlyQA-VLMEvalKit are different from the results on \
the original VisOnlyQA. VisOnlyQA-VLMEvalKit does not include the \
chemistry__shape_multi split and uses a different evaluation prompt. Please \
explicitly specify the version of the dataset when you report results.')
return acc
@@ -267,7 +291,7 @@ class MMMUDataset(ImageMCQDataset):
}
DATASET_MD5 = {
'MMMU_DEV_VAL': '521afc0f3bf341e6654327792781644d',
'MMMU_DEV_VAL': '585e8ad75e73f75dcad265dfd0417d64',
'MMMU_TEST': 'c19875d11a2d348d07e5eb4bdf33166d',
}
@@ -381,13 +405,72 @@ class MUIRDataset(ImageMCQDataset):
class GMAIMMBenchDataset(ImageMCQDataset):
DATASET_URL = {
'GMAI-MMBench_VAL': 'https://huggingface.co/datasets/VLMEval/GMAI-MMBench/resolve/main/GMAI-MMBench_VAL.tsv'
'GMAI-MMBench_VAL': 'https://huggingface.co/datasets/VLMEval/GMAI-MMBench/resolve/main/GMAI-MMBench_VAL.tsv',
'GMAI_mm_bench_TEST_part_1': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_1.tsv', # noqa: E501
'GMAI_mm_bench_TEST_part_2': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_2.tsv', # noqa: E501
'GMAI_mm_bench_TEST_part_3': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_3.tsv', # noqa: E501
'GMAI_mm_bench_TEST_part_4': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_4.tsv', # noqa: E501
'GMAI_mm_bench_TEST_part_5': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_5.tsv', # noqa: E501
'GMAI_mm_bench_TEST_part_6': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_6.tsv', # noqa: E501
'GMAI_mm_bench_TEST_part_7': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_7.tsv', # noqa: E501
'GMAI_mm_bench_TEST_part_8': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_8.tsv', # noqa: E501
'GMAI_mm_bench_TEST_part_9': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_9.tsv', # noqa: E501
'GMAI_mm_bench_TEST_part_10': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_10.tsv', # noqa: E501
'GMAI_mm_bench_TEST_part_11': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_11.tsv', # noqa: E501
}
DATASET_MD5 = {
'GMAI-MMBench_VAL': '254bd581627866f1c499d3d6b4422324'
'GMAI-MMBench_VAL': '254bd581627866f1c499d3d6b4422324',
'GMAI_mm_bench_TEST_part_1': '900d735231230a63f4ed45665c078ef4',
'GMAI_mm_bench_TEST_part_2': '1b27ab621386945d7e4a765ad2d22b0e',
'GMAI_mm_bench_TEST_part_3': '44bdc2b6267dd505d529b8cad06f0fb2',
'GMAI_mm_bench_TEST_part_4': '5a04a04fcac9f1466709f242fdb80acb',
'GMAI_mm_bench_TEST_part_5': 'c70baf8909eda9af0ddeab275c721336',
'GMAI_mm_bench_TEST_part_6': '825abc39596b644dead9350d0cfa3b96',
'GMAI_mm_bench_TEST_part_7': 'defb8aed2fb77365a76b6b9abd6a2701',
'GMAI_mm_bench_TEST_part_8': 'ff490d60b85f2bb0abb67a435b298c65',
'GMAI_mm_bench_TEST_part_9': 'ff67c86f40da93b09139ac1d1ba5dc6b',
'GMAI_mm_bench_TEST_part_10': '3dae94627b9ac0fe00180d4780fbf6dc',
'GMAI_mm_bench_TEST_part_11': 'd08dc813f0eb6bbab63cae2a9d113c4b',
}
@classmethod
def supported_datasets(cls):
return ['GMAI-MMBench_VAL', 'GMAI-MMBench_TEST']
def load_data(self, dataset):
if dataset == 'GMAI-MMBench_VAL':
data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')
if file_size(data_path, 'GB') > 1:
local_path = data_path.replace('.tsv', '_local.tsv')
if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL'):
from ..tools import LOCALIZE
LOCALIZE(data_path, local_path)
data_path = local_path
return load(data_path)
elif dataset == 'GMAI-MMBench_TEST':
dfs = []
for part_num in range(1, 12):
part_name = f'GMAI_mm_bench_TEST_part_{part_num}'
url = self.DATASET_URL[part_name]
file_md5 = self.DATASET_MD5.get(part_name)
tsv_path = osp.join(LMUDataRoot(), f'{part_name}.tsv')
if not osp.exists(tsv_path) or (file_md5 and md5(tsv_path) != file_md5):
download_file(url, filename=tsv_path)
local_path = tsv_path.replace('.tsv', '_local.tsv')
if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL'):
from ..tools import LOCALIZE
LOCALIZE(tsv_path, local_path)
tsv_path = local_path
# 加载数据
df = load(tsv_path)
dfs.append(df)
# 合并所有数据
data = pd.concat(dfs, ignore_index=True)
return data
else:
raise ValueError(f"未知的数据集:{dataset}")
def report_acc_by_groups(self, df, group_column):
res = defaultdict(list)
@@ -402,7 +485,7 @@ class GMAIMMBenchDataset(ImageMCQDataset):
res['Overall'] = [np.mean(df[df['split'] == sp]['hit']) for sp in res['split']]
if group_column not in df:
raise ValueError(f"Column '{group_column}' not found in dataframe.")
raise ValueError(f"Column '{group_column}' not found in dataframe.") # noqa: E713
abilities = list(set(df[group_column]))
abilities = ['None' if isinstance(ab, float) and pd.isna(ab) else ab for ab in abilities]
@@ -470,6 +553,277 @@ class GMAIMMBenchDataset(ImageMCQDataset):
return acc
class MMERealWorld(ImageMCQDataset):
TYPE = 'MMERealWorld'
DATASET_MD5 = {
'MME-RealWorld': '271c33ec814c39533c467ec6fb8a6f36',
'MME-RealWorld-Lite': '4c17057d7d3b6c4a0d4397c3dae0881c',
'MME-RealWorld-CN': 'daaa763d52a760a38606d5dedb3fe444',
}
SYS = {
'MME-RealWorld': (
'Select the best answer to the above multiple-choice question based on the image. '
'Respond with only the letter (A, B, C, D, or E) of the correct option. \n'
'The best answer is:'
),
'MME-RealWorld-Lite': (
'Select the best answer to the above multiple-choice question based on the image. '
'Respond with only the letter (A, B, C, D, or E) of the correct option. \n'
'The best answer is:'
),
'MME-RealWorld-CN': (
'根据图像选择上述多项选择题的最佳答案。只需回答正确选项的字母A, B, C, D 或 E\n'
'最佳答案为:'
),
}
@classmethod
def supported_datasets(cls):
return ['MME-RealWorld', 'MME-RealWorld-CN', 'MME-RealWorld-Lite',]
def load_data(
self, dataset="MME-RealWorld", repo_id="yifanzhang114/MME-RealWorld-Base64"
):
def check_integrity(pth):
data_file = osp.join(pth, f"{dataset}.tsv")
if not os.path.exists(data_file):
return False
if md5(data_file) != self.DATASET_MD5[dataset]:
return False
return True
def generate_tsv(pth):
tsv_file = os.path.join(pth, f"{dataset}.tsv")
if os.path.exists(tsv_file):
print(f"{tsv_file} already exists.")
return
json_dir = os.path.join(pth, dataset)
json_files = [f for f in os.listdir(json_dir) if f.endswith(".json")]
data_list = []
for json_file in json_files:
with open(os.path.join(json_dir, json_file), "r") as f:
data = json.load(f)
for item in tqdm(data):
choice_prompt = (
"The choices are listed below:\n"
if dataset in ["MME-RealWorld", "MME-RealWorld-Lite"]
else "选项如下所示:\n"
)
data_list.append(
{
"index": item["index"],
"image": item["image"],
"question": item["question"],
"multi-choice options": choice_prompt
+ "\n".join(item["multi-choice options"]),
"A": item["multi-choice options"][0][4:],
"B": item["multi-choice options"][1][4:],
"C": item["multi-choice options"][2][4:],
"D": item["multi-choice options"][3][4:],
"E": item["multi-choice options"][4][4:],
"answer": item["answer"],
"category": item["category"],
"l2-category": item["l2-category"],
}
)
df = pd.DataFrame(data_list)
df.to_csv(tsv_file, sep="\t", index=False)
print(f"TSV file saved to {tsv_file}")
# Check if dataset is cached and has integrity
if dataset == "MME-RealWorld-Lite":
url = 'https://huggingface.co/datasets/yifanzhang114/MME-RealWorld-Base64/resolve/main/mme_realworld_lite.tsv' # noqa: E501
file_md5 = (
self.DATASET_MD5[dataset] if dataset in self.DATASET_MD5 else None
)
datas = self.prepare_tsv(url, file_md5)
choice_prompt = "The choices are listed below:\n"
for index, item in datas.iterrows():
options = eval(item["multi-choice options"])
datas.loc[index, "multi-choice options"] = choice_prompt + "\n".join(
options
)
datas.loc[index, "A"] = options[0][4:]
datas.loc[index, "B"] = options[1][4:]
datas.loc[index, "C"] = options[2][4:]
datas.loc[index, "D"] = options[3][4:]
datas.loc[index, "E"] = options[4][4:]
return datas
update_flag = False
cache_path = get_cache_path(repo_id)
if cache_path is not None and check_integrity(cache_path):
dataset_path = cache_path
print(f"Using cached dataset from {cache_path}")
else:
from huggingface_hub import snapshot_download
# Download or find the dataset path
dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
generate_tsv(dataset_path)
update_flag = True
data_path = os.path.join(dataset_path, f"{dataset}.tsv")
if file_size(data_path, "GB") > 1:
local_path = data_path.replace(".tsv", "_local.tsv")
if (
not osp.exists(local_path)
or os.environ.get("FORCE_LOCAL", None)
or update_flag
):
from vlmeval.tools import LOCALIZE
LOCALIZE(data_path, local_path)
data_path = local_path
return load(data_path)
def post_build(self, dataset):
self.TYPE = 'MMERealWorld'
# Given one data record, return the built prompt (a multi-modal message), can override
def build_prompt(self, line):
if isinstance(line, int):
line = self.data.iloc[line]
if self.meta_only:
tgt_path = toliststr(line['image_path'])
else:
tgt_path = self.dump_image(line)
question = line['question']
choice_prompt = line['multi-choice options'] + '\n'
question += ' ' + choice_prompt + self.SYS[self.dataset_name]
msgs = []
if isinstance(tgt_path, list):
msgs.extend([dict(type='image', value=p) for p in tgt_path])
else:
msgs = [dict(type='image', value=tgt_path)]
msgs.append(dict(type='text', value=question))
return msgs
# It returns a dictionary
@classmethod
def evaluate(self, eval_file, **judge_kwargs):
from .utils.multiple_choice import extract_characters_regex, get_dimension_rating
assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
FAIL_MSG = 'Failed to obtain answer via API.'
tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
tgt_file = eval_file.replace('.xlsx', '_rating.json')
score_file = eval_file.replace('.xlsx', '_score.xlsx')
if not osp.exists(score_file):
res = {} if not osp.exists(tmp_file) else load(tmp_file)
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
data = load(eval_file)
cnt_rejected = 0
data_un = data[~pd.isna(data['prediction'])]
for idx in data['index']:
ans = data.loc[data['index'] == idx, 'answer'].values[0]
pred = data.loc[data['index'] == idx, 'prediction'].values[0]
extract_pred = extract_characters_regex(pred)
if extract_pred == '':
cnt_rejected += 1
data.loc[data['index'] == idx, 'score'] = 0
else:
data.loc[data['index'] == idx, 'score'] = int(extract_pred == ans)
print(
f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
f'failed to obtain the score for another {cnt_rejected} questions. '
f'Those questions will be counted as 0 score in ALL rating.'
)
dump(data, score_file)
rating = get_dimension_rating(score_file)
dump(rating, tgt_file)
return rating
class HRBenchDataset(ImageMCQDataset):
DATASET_URL = {
'HRBench4K': 'https://huggingface.co/datasets/DreamMr/HR-Bench/resolve/main/hr_bench_4k.tsv',
'HRBench8K': 'https://huggingface.co/datasets/DreamMr/HR-Bench/resolve/main/hr_bench_8k.tsv',
}
DATASET_MD5 = {
'HRBench4K': 'f6b041b03d49543494b8a56d2e35be65',
'HRBench8K': '274c9c7f89329b804a4723178a00219c',
}
def evaluate(self, eval_file, **judge_kwargs):
assert os.path.exists(eval_file), '{} does not exist!'.format(eval_file)
from .utils.multiple_choice import mcq_vanilla_eval
from .utils.hrbench import report_acc_hrbench
nproc = judge_kwargs.pop('nproc', 4)
suffix = eval_file.split('.')[-1]
model = judge_kwargs.get('model', 'extract_matching')
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
name_str = name_str_map[model] if model in name_str_map else model
if model == 'exact_matching':
model = None
elif gpt_key_set():
model = build_judge(**judge_kwargs)
if not model.working():
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
warnings.warn(DEBUG_MESSAGE)
model = None
else:
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
model = None
result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
data = load(eval_file)
data = data.sort_values(by='index')
data['prediction'] = [str(x) for x in data['prediction']]
# If not choice label, then use lower case
for k in data.keys():
data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
meta = self.data
meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
data_map = {x: y for x, y in zip(data['index'], data['question'])}
for k in data_map:
assert k in meta_q_map, (
f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
)
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
if osp.exists(score_file):
acc = load(score_file)
return acc
data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
acc = report_acc_hrbench(data)
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
dump(acc, score_file)
return acc
class CustomMCQDataset(ImageMCQDataset):
def load_data(self, dataset):
@@ -482,3 +836,69 @@ class CustomMCQDataset(ImageMCQDataset):
LOCALIZE(data_path, local_path)
data_path = local_path
return load(data_path)
class NaturalBenchDataset(ImageMCQDataset):
DATASET_URL = {
'NaturalBenchDataset': (
'https://huggingface.co/datasets/BaiqiL/'
'NaturalBench/resolve/main/NaturalBenchDataset.tsv'
),
}
DATASET_MD5 = {
'NaturalBenchDataset':'dbe25b044bc35696426381e9ba4fe930',
}
def build_prompt(self, line):
SUFFIX_FOR_VQA = {
"yes_no": "Please answer Yes or No.",
"multiple_choice": "Please output the letter corresponding to the correct option."
}
if isinstance(line, int):
line = self.data.iloc[line]
if self.meta_only:
tgt_path = toliststr(line['image_path'])
else:
tgt_path = self.dump_image(line)
question = line['question']
prompt = f'{question} {SUFFIX_FOR_VQA[line["type"]]}'
msgs = []
if isinstance(tgt_path, list):
msgs.extend([dict(type='image', value=p) for p in tgt_path])
else:
msgs = [dict(type='image', value=tgt_path)]
msgs.append(dict(type='text', value=prompt))
return msgs
def evaluate(self, eval_file, **judge_kwargs):
from .utils.naturalbench import extract_answer, get_scores
data = load(eval_file)
data = data.sort_values(by='index')
predictions = [str(x) for x in data['prediction']]
answers = [str(x) for x in data['answer']]
indexs = [str(x) for x in data['index']]
meta = self.data
types = [str(x) for x in meta['type']]
results = {}
assert len(predictions) == len(answers) == len(indexs) == len(types) == (1900 * 4)
number_answered_samples = len(predictions) // 4
for i in range(number_answered_samples):
results[i] = {
"q0_i0": extract_answer(predictions[i * 4], types[i * 4]),
"q0_i1": extract_answer(predictions[i * 4 + 1], types[i * 4 + 1]),
"q1_i0": extract_answer(predictions[i * 4 + 2], types[i * 4 + 2]),
"q1_i1": extract_answer(predictions[i * 4 + 3], types[i * 4 + 3])
}
scores = get_scores(results)
print(scores)
score_file = 'NaturalBench_acc.csv'
df = pd.DataFrame(list(scores.items()), columns=['Metric', 'Score'])
dump(df, score_file)
return scores