mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-05 18:29:18 +08:00
Modify eval_mm for MiniCPM-V 2.6
This commit is contained in:
@@ -1,12 +1,7 @@
|
||||
from .matching_util import can_infer, can_infer_option, can_infer_text
|
||||
from .mp_util import track_progress_rich
|
||||
from .custom_prompt import CustomPrompt
|
||||
from .dataset_config import dataset_URLs, img_root_map, DATASET_TYPE, abbr2full
|
||||
from .dataset import TSVDataset, split_MMMU, MMMU_result_transfer
|
||||
|
||||
|
||||
__all__ = [
|
||||
'can_infer', 'can_infer_option', 'can_infer_text', 'track_progress_rich',
|
||||
'TSVDataset', 'dataset_URLs', 'img_root_map', 'DATASET_TYPE', 'CustomPrompt',
|
||||
'split_MMMU', 'abbr2full'
|
||||
]
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
from ..smp import *
|
||||
from .dataset_config import img_root_map
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
class CustomPrompt:
|
||||
|
||||
@abstractmethod
|
||||
def use_custom_prompt(self, dataset):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def build_prompt(self, line, dataset):
|
||||
raise NotImplementedError
|
||||
|
||||
def dump_image(self, line, dataset):
|
||||
ROOT = LMUDataRoot()
|
||||
assert isinstance(dataset, str)
|
||||
img_root = osp.join(ROOT, 'images', img_root_map[dataset] if dataset in img_root_map else dataset)
|
||||
os.makedirs(img_root, exist_ok=True)
|
||||
if isinstance(line['image'], list):
|
||||
tgt_path = []
|
||||
assert 'image_path' in line
|
||||
for img, im_name in zip(line['image'], line['image_path']):
|
||||
path = osp.join(img_root, im_name)
|
||||
if not read_ok(path):
|
||||
decode_base64_to_image_file(img, path)
|
||||
tgt_path.append(path)
|
||||
else:
|
||||
tgt_path = osp.join(img_root, f"{line['index']}.jpg")
|
||||
if not read_ok(tgt_path):
|
||||
decode_base64_to_image_file(line['image'], tgt_path)
|
||||
return tgt_path
|
||||
@@ -1,190 +0,0 @@
|
||||
import pandas as pd
|
||||
import hashlib
|
||||
from ..smp import *
|
||||
from .dataset_config import dataset_URLs, dataset_md5_dict, DATASET_TYPE
|
||||
from .custom_prompt import CustomPrompt
|
||||
from .matching_util import can_infer
|
||||
|
||||
|
||||
def isliststr(s):
|
||||
return (s[0] == '[') and (s[-1] == ']')
|
||||
|
||||
|
||||
def check_md5(data_path, dataset):
|
||||
if dataset not in dataset_md5_dict:
|
||||
warnings.warn(f'We do not have an md5 record for dataset {dataset}, skip the md5 check. ')
|
||||
return True
|
||||
assert osp.exists(data_path)
|
||||
with open(data_path, 'rb') as f:
|
||||
hash = hashlib.new('md5')
|
||||
for chunk in iter(lambda: f.read(2**20), b''):
|
||||
hash.update(chunk)
|
||||
if str(hash.hexdigest()) == dataset_md5_dict[dataset]:
|
||||
return True
|
||||
else:
|
||||
warnings.warn('this data file is incomplete, so it needs to be downloaded again.')
|
||||
return False
|
||||
|
||||
|
||||
def split_MMMU(msgs):
|
||||
text, images = None, []
|
||||
for s in msgs:
|
||||
if s['type'] == 'image':
|
||||
images.append(s['value'])
|
||||
elif s['type'] == 'text':
|
||||
assert text is None
|
||||
text = s['value']
|
||||
text_segs = text.split('<image ')
|
||||
segs = [dict(type='text', value=text_segs[0])]
|
||||
for i, seg in enumerate(text_segs):
|
||||
if i == 0:
|
||||
continue
|
||||
assert istype(seg[0], int) and seg[1] == '>'
|
||||
image_idx = int(seg[0]) - 1
|
||||
segs.append(dict(type='image', value=images[image_idx]))
|
||||
segs.append(dict(type='text', value=seg[2:]))
|
||||
return segs
|
||||
|
||||
|
||||
def MMMU_result_transfer(result_path):
|
||||
res = {}
|
||||
result_data = load(result_path)
|
||||
mcq = result_data['A'].notna()
|
||||
lt = len(result_data)
|
||||
for i in range(lt):
|
||||
line = result_data.iloc[i]
|
||||
if mcq[i]:
|
||||
options = {
|
||||
cand: line[cand]
|
||||
for cand in string.ascii_uppercase
|
||||
if cand in line and not pd.isna(line[cand])
|
||||
}
|
||||
prediction = line['prediction']
|
||||
infer_prediction = can_infer(prediction, options)
|
||||
res[line['id']] = infer_prediction
|
||||
else:
|
||||
res[line['id']] = line['prediction']
|
||||
result_json = result_path.replace('.xlsx', '.json')
|
||||
dump(res, result_json)
|
||||
return result_json
|
||||
|
||||
|
||||
class TSVDataset(CustomPrompt):
|
||||
|
||||
def __init__(self, dataset='MMBench', skip_noimg=True):
|
||||
|
||||
self.data_root = LMUDataRoot()
|
||||
assert osp.exists(self.data_root)
|
||||
|
||||
self.dataset = dataset
|
||||
self.dataset_type = DATASET_TYPE(dataset)
|
||||
|
||||
if dataset in dataset_URLs:
|
||||
url = dataset_URLs[dataset]
|
||||
file_name = url.split('/')[-1]
|
||||
data_path = osp.join(self.data_root, file_name)
|
||||
|
||||
if osp.exists(data_path) and check_md5(data_path, dataset):
|
||||
pass
|
||||
elif osp.isfile(url):
|
||||
# If url is actually a file path, use it directly
|
||||
data_path = url
|
||||
else:
|
||||
warnings.warn('The dataset tsv is not downloaded')
|
||||
download_file(url, data_path)
|
||||
else:
|
||||
data_path = osp.join(self.data_root, dataset + '.tsv')
|
||||
assert osp.exists(data_path)
|
||||
|
||||
data = load(data_path)
|
||||
self.skip_noimg = skip_noimg
|
||||
if skip_noimg and 'image' in data:
|
||||
data = data[~pd.isna(data['image'])]
|
||||
|
||||
# Prompt for Captioning
|
||||
if listinstr(['COCO'], dataset):
|
||||
data['question'] = [(
|
||||
'Please describe this image in general. Directly provide the description, '
|
||||
'do not include prefix like "This image depicts". '
|
||||
)] * len(data)
|
||||
|
||||
data['index'] = [str(x) for x in data['index']]
|
||||
|
||||
self.meta_only = True
|
||||
if 'image' in data:
|
||||
data['image'] = [str(x) for x in data['image']]
|
||||
|
||||
image_map = {x: y for x, y in zip(data['index'], data['image'])}
|
||||
for k in image_map:
|
||||
if len(image_map[k]) <= 64:
|
||||
idx = image_map[k]
|
||||
assert idx in image_map and len(image_map[idx]) > 64
|
||||
image_map[k] = image_map[idx]
|
||||
|
||||
data['image'] = [
|
||||
eval(image_map[k]) if isliststr(image_map[k]) else image_map[k]
|
||||
for k in data['index']
|
||||
]
|
||||
self.meta_only = False
|
||||
|
||||
if 'image_path' in data:
|
||||
data['image_path'] = [
|
||||
eval(pths) if isliststr(pths) else pths for pths in data['image_path']
|
||||
]
|
||||
|
||||
if np.all([istype(x, int) for x in data['index']]):
|
||||
data['index'] = [int(x) for x in data['index']]
|
||||
|
||||
self.data = data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def build_prompt(self, line, dataset=None):
|
||||
if dataset is None:
|
||||
dataset = self.dataset
|
||||
|
||||
if isinstance(line, int):
|
||||
line = self.data.iloc[line]
|
||||
|
||||
if self.meta_only:
|
||||
tgt_path = line['image_path']
|
||||
else:
|
||||
tgt_path = self.dump_image(line, dataset)
|
||||
|
||||
prompt = line['question']
|
||||
if DATASET_TYPE(dataset) == 'multi-choice':
|
||||
question = line['question']
|
||||
options = {
|
||||
cand: line[cand]
|
||||
for cand in string.ascii_uppercase
|
||||
if cand in line and not pd.isna(line[cand])
|
||||
}
|
||||
options_prompt = 'Options:\n'
|
||||
for key, item in options.items():
|
||||
options_prompt += f'{key}. {item}\n'
|
||||
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
|
||||
prompt = ''
|
||||
if hint is not None:
|
||||
prompt += f'Hint: {hint}\n'
|
||||
prompt += f'Question: {question}\n'
|
||||
if len(options):
|
||||
prompt += options_prompt
|
||||
prompt += 'Please select the correct answer from the options above. \n'
|
||||
elif DATASET_TYPE(dataset) == 'VQA':
|
||||
if listinstr(['ocrvqa', 'textvqa', 'chartqa', 'docvqa'], dataset.lower()):
|
||||
prompt += '\nPlease try to answer the question with short words or phrases if possible\n.'
|
||||
|
||||
msgs = []
|
||||
if isinstance(tgt_path, list):
|
||||
msgs.extend([dict(type='image', value=p) for p in tgt_path])
|
||||
else:
|
||||
msgs = [dict(type='image', value=tgt_path)]
|
||||
msgs.append(dict(type='text', value=prompt))
|
||||
|
||||
return msgs
|
||||
|
||||
def display(self, line):
|
||||
if isinstance(line, int):
|
||||
line = self.data.iloc[line]
|
||||
mmqa_display(line)
|
||||
@@ -1,158 +0,0 @@
|
||||
from ..smp import listinstr
|
||||
|
||||
dataset_URLs = {
|
||||
# MMBench v1.0
|
||||
'MMBench_DEV_EN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN.tsv',
|
||||
'MMBench_TEST_EN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN.tsv',
|
||||
'MMBench_DEV_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN.tsv',
|
||||
'MMBench_TEST_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN.tsv',
|
||||
'MMBench': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench.tsv', # Internal Only
|
||||
'MMBench_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN.tsv', # Internal Only
|
||||
# MMBench v1.1
|
||||
'MMBench_DEV_EN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN_V11.tsv',
|
||||
'MMBench_TEST_EN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN_V11.tsv',
|
||||
'MMBench_DEV_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN_V11.tsv',
|
||||
'MMBench_TEST_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN_V11.tsv',
|
||||
'MMBench_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_V11.tsv', # Internal Only
|
||||
'MMBench_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN_V11.tsv', # Internal Only
|
||||
# CCBench
|
||||
'CCBench': 'https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv',
|
||||
'MME': 'https://opencompass.openxlab.space/utils/VLMEval/MME.tsv',
|
||||
'SEEDBench_IMG': 'https://opencompass.openxlab.space/utils/VLMEval/SEEDBench_IMG.tsv',
|
||||
'CORE_MM': 'https://opencompass.openxlab.space/utils/VLMEval/CORE_MM.tsv',
|
||||
'MMVet': 'https://opencompass.openxlab.space/utils/VLMEval/MMVet.tsv',
|
||||
'COCO_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/COCO_VAL.tsv',
|
||||
'OCRVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TEST.tsv',
|
||||
'OCRVQA_TESTCORE': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TESTCORE.tsv',
|
||||
'TextVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/TextVQA_VAL.tsv',
|
||||
'MMMU_DEV_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv',
|
||||
'MMMU_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv',
|
||||
'MathVista_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv',
|
||||
'ScienceQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_VAL.tsv',
|
||||
'ScienceQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_TEST.tsv',
|
||||
'HallusionBench': 'https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv',
|
||||
'DocVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/DocVQA_VAL.tsv',
|
||||
'DocVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/DocVQA_TEST.tsv',
|
||||
'InfoVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/InfoVQA_VAL.tsv',
|
||||
'InfoVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/InfoVQA_TEST.tsv',
|
||||
'AI2D_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv',
|
||||
'LLaVABench': 'https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv',
|
||||
'OCRBench': 'https://opencompass.openxlab.space/utils/VLMEval/OCRBench.tsv',
|
||||
'ChartQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/ChartQA_TEST.tsv',
|
||||
'MMStar': 'https://opencompass.openxlab.space/utils/VLMEval/MMStar.tsv',
|
||||
'RealWorldQA': 'https://opencompass.openxlab.space/utils/VLMEval/RealWorldQA.tsv',
|
||||
'POPE': 'https://opencompass.openxlab.space/utils/VLMEval/POPE.tsv',
|
||||
}
|
||||
|
||||
dataset_md5_dict = {
|
||||
# MMBench v1.0
|
||||
'MMBench_DEV_EN': 'b6caf1133a01c6bb705cf753bb527ed8',
|
||||
'MMBench_TEST_EN': '6939fadb0ce626fefc0bdc9c64efc528',
|
||||
'MMBench_DEV_CN': '08b8fc3324a5ed74155350f57be69fbd',
|
||||
'MMBench_TEST_CN': '7e1239baf0ee4c8b513e19705a0f317e',
|
||||
'MMBench': '4115aea3383f3dd0083be6a633e0f820', # Internal Only
|
||||
'MMBench_CN': '2e053ffc90ea598b1feae13c36dc13ee', # Internal Only
|
||||
# MMBench v1.1
|
||||
'MMBench_DEV_EN_V11': '30c05be8f2f347a50be25aa067248184',
|
||||
'MMBench_TEST_EN_V11': '26f0f15381a21720255091d3e0316ce6',
|
||||
'MMBench_DEV_CN_V11': '593f9b5f6bea453d870a798b34ae4f37',
|
||||
'MMBench_TEST_CN_V11': '74bbe4556dac745613c7cbe5ad787050',
|
||||
'MMBench_V11': 'b9276414f57af1308dcc4d0cd9b42e7c', # Internal Only
|
||||
'MMBench_CN_V11': '95f6980dd1b4de38e3cbffe0305a3f25', # Internal Only
|
||||
# CCBench
|
||||
'CCBench': '1de88b4257e7eee3f60b18d45eda6f07',
|
||||
'MME': 'b36b43c3f09801f5d368627fb92187c3',
|
||||
'SEEDBench_IMG': '68017231464752261a2526d6ca3a10c0',
|
||||
'CORE_MM': '8a8da2f2232e79caf98415bfdf0a202d',
|
||||
'MMVet': '748aa6d4aa9d4de798306a63718455e3',
|
||||
'COCO_VAL': '72a5079dead060269ac222c5aa5128af',
|
||||
'OCRVQA_TEST': 'ca46a6d74b403e9d6c0b670f6fc00db9',
|
||||
'OCRVQA_TESTCORE': 'c5239fe77db8bdc1f2ad8e55e0d1fe97',
|
||||
'TextVQA_VAL': 'b233b31f551bbf4056f2f955da3a92cd',
|
||||
'MMMU_DEV_VAL': '521afc0f3bf341e6654327792781644d',
|
||||
'MMMU_TEST': 'c19875d11a2d348d07e5eb4bdf33166d',
|
||||
'MathVista_MINI': 'f199b98e178e5a2a20e7048f5dcb0464',
|
||||
'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3',
|
||||
'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f',
|
||||
'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c',
|
||||
'DocVQA_VAL': 'd5ee77e1926ff10690d469c56b73eabf',
|
||||
'DocVQA_TEST': '6a2f28cac26ef2d3447374e8c6f6c8e9',
|
||||
'InfoVQA_VAL': '2342e9c225222f0ef4dec545ebb126fe',
|
||||
'InfoVQA_TEST': 'df535bf51b88dc9718252c34131a6227',
|
||||
'AI2D_TEST': '0f593e0d1c7df9a3d69bf1f947e71975',
|
||||
'LLaVABench': 'd382a093f749a697820d3dadd61c8428',
|
||||
'OCRBench': 'e953d98a987cc6e26ef717b61260b778',
|
||||
'ChartQA_TEST': 'c902e0aa9be5582a7aad6dcf52734b42',
|
||||
'MMStar': 'e1ecd2140806c1b1bbf54b43372efb9e',
|
||||
'RealWorldQA': '92321028d2bc29040284b6674721e48f',
|
||||
'POPE': 'c12f5acb142f2ef1f85a26ba2fbe41d5',
|
||||
}
|
||||
|
||||
img_root_map = {k: k for k in dataset_URLs}
|
||||
img_root_map.update({
|
||||
# MMBench v1.0
|
||||
'MMBench_DEV_EN': 'MMBench',
|
||||
'MMBench_TEST_EN': 'MMBench',
|
||||
'MMBench_DEV_CN': 'MMBench',
|
||||
'MMBench_TEST_CN': 'MMBench',
|
||||
'MMBench': 'MMBench', # Internal Only
|
||||
'MMBench_CN': 'MMBench', # Internal Only
|
||||
# MMBench v1.1
|
||||
'MMBench_DEV_EN_V11': 'MMBench_V11',
|
||||
'MMBench_TEST_EN_V11': 'MMBench_V11',
|
||||
'MMBench_DEV_CN_V11': 'MMBench_V11',
|
||||
'MMBench_TEST_CN_V11': 'MMBench_V11',
|
||||
'MMBench_V11': 'MMBench_V11', # Internal Only
|
||||
'MMBench_CN_V11': 'MMBench_V11', # Internal Only
|
||||
'COCO_VAL': 'COCO',
|
||||
'OCRVQA_TEST': 'OCRVQA',
|
||||
'OCRVQA_TESTCORE': 'OCRVQA',
|
||||
'TextVQA_VAL': 'TextVQA',
|
||||
'MMMU_DEV_VAL': 'MMMU',
|
||||
'MMMU_TEST': 'MMMU',
|
||||
'MathVista_MINI': 'MathVista',
|
||||
'HallusionBench': 'Hallusion',
|
||||
'DocVQA_VAL': 'DocVQA',
|
||||
'DocVQA_TEST': 'DocVQA_TEST',
|
||||
'OCRBench': 'OCRBench',
|
||||
'ChartQA_TEST': 'ChartQA_TEST',
|
||||
'InfoVQA_VAL': 'InfoVQA_VAL',
|
||||
'InfoVQA_TEST': 'InfoVQA_TEST',
|
||||
'MMStar': 'MMStar',
|
||||
'RealWorldQA': 'RealWorldQA',
|
||||
'POPE': 'POPE',
|
||||
})
|
||||
|
||||
assert set(dataset_URLs) == set(img_root_map)
|
||||
|
||||
|
||||
def DATASET_TYPE(dataset):
|
||||
# Dealing with Custom Dataset
|
||||
dataset = dataset.lower()
|
||||
if listinstr(['mmbench', 'seedbench', 'ccbench', 'mmmu', 'scienceqa', 'ai2d', 'mmstar', 'realworldqa'], dataset):
|
||||
return 'multi-choice'
|
||||
elif listinstr(['mme', 'hallusion', 'pope'], dataset):
|
||||
return 'Y/N'
|
||||
elif 'coco' in dataset:
|
||||
return 'Caption'
|
||||
elif listinstr(['ocrvqa', 'textvqa', 'chartqa', 'mathvista', 'docvqa', 'infovqa', 'llavabench',
|
||||
'mmvet', 'ocrbench'], dataset):
|
||||
return 'VQA'
|
||||
else:
|
||||
if dataset not in dataset_URLs:
|
||||
import warnings
|
||||
warnings.warn(f"Dataset {dataset} not found in dataset_URLs, will use 'multi-choice' as the default TYPE.")
|
||||
return 'multi-choice'
|
||||
else:
|
||||
return 'QA'
|
||||
|
||||
|
||||
def abbr2full(s):
|
||||
datasets = [x for x in img_root_map]
|
||||
ins = [s in d for d in datasets]
|
||||
if sum(ins) == 1:
|
||||
for d in datasets:
|
||||
if s in d:
|
||||
return d
|
||||
else:
|
||||
return s
|
||||
@@ -145,7 +145,7 @@ def track_progress_rich(func: Callable,
|
||||
results = []
|
||||
for task in tasks:
|
||||
result, idx = worker(task)
|
||||
results.append(worker(task)[0])
|
||||
results.append(result)
|
||||
if save is not None:
|
||||
with portalocker.Lock(save, timeout=5) as fh:
|
||||
ans = load(save)
|
||||
|
||||
97
eval_mm/vlmevalkit/vlmeval/utils/result_transfer.py
Normal file
97
eval_mm/vlmevalkit/vlmeval/utils/result_transfer.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from ..smp import *
|
||||
from ..dataset.utils.judge_util import build_judge
|
||||
from ..dataset.utils.multiple_choice import extract_answer_from_item
|
||||
from .matching_util import can_infer
|
||||
from .mp_util import track_progress_rich
|
||||
|
||||
|
||||
def MMMU_result_transfer(result_path):
|
||||
res = {}
|
||||
result_data = load(result_path)
|
||||
mcq = result_data['A'].notna()
|
||||
lt = len(result_data)
|
||||
for i in range(lt):
|
||||
line = result_data.iloc[i]
|
||||
if mcq[i]:
|
||||
options = {
|
||||
cand: line[cand]
|
||||
for cand in string.ascii_uppercase
|
||||
if cand in line and not pd.isna(line[cand])
|
||||
}
|
||||
prediction = line['prediction']
|
||||
infer_prediction = can_infer(prediction, options)
|
||||
res[line['id']] = infer_prediction
|
||||
else:
|
||||
res[line['id']] = line['prediction']
|
||||
result_json = result_path.replace('.xlsx', '.json')
|
||||
dump(res, result_json)
|
||||
return result_json
|
||||
|
||||
|
||||
def MMTBench_result_transfer(eval_file, dataset='default', **judge_kwargs):
|
||||
logger = get_logger('Evaluation')
|
||||
nproc = judge_kwargs.pop('nproc', 4)
|
||||
|
||||
rd.seed(2680)
|
||||
suffix = eval_file.split('.')[-1]
|
||||
model = judge_kwargs['model']
|
||||
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
||||
name_str_map = {
|
||||
'chatgpt-0125': 'openai',
|
||||
'gpt-4-0125': 'gpt4'
|
||||
}
|
||||
name_str = name_str_map[model] if model in name_str_map else model
|
||||
|
||||
if model == 'exact_matching':
|
||||
model = None
|
||||
elif gpt_key_set():
|
||||
model = build_judge(**judge_kwargs)
|
||||
if not model.working():
|
||||
logger.error('The OPENAI API is not working properly, will use exact matching for evaluation')
|
||||
model = None
|
||||
else:
|
||||
logger.error('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
||||
model = None
|
||||
|
||||
logger.info(f'Evaluating {eval_file}')
|
||||
result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_option.pkl')
|
||||
result = {}
|
||||
if osp.exists(result_file):
|
||||
result = load(result_file)
|
||||
|
||||
data = load(eval_file)
|
||||
assert 'index' in data, 'Essentail columns missing in the eval_file.'
|
||||
|
||||
data = data.sort_values(by='index')
|
||||
data['prediction'] = [str(x) for x in data['prediction']]
|
||||
for k in data.keys():
|
||||
data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
|
||||
|
||||
idx2lines = {data.iloc[i]['index']: data.iloc[i] for i in range(len(data))}
|
||||
idx2lines = {k: v for k, v in idx2lines.items() if k not in result}
|
||||
|
||||
indices = list(idx2lines.keys())
|
||||
lines = [idx2lines[i] for i in indices]
|
||||
tups = [(model, line) for line in lines]
|
||||
res = track_progress_rich(
|
||||
extract_answer_from_item,
|
||||
tups,
|
||||
nproc=nproc,
|
||||
chunksize=nproc,
|
||||
save=result_file,
|
||||
keys=indices)
|
||||
|
||||
for i, r in zip(indices, res):
|
||||
if i in result:
|
||||
assert result[i]['opt'] == r['opt'] and result[i]['log'] == r['log']
|
||||
else:
|
||||
result[i] = r
|
||||
|
||||
indices = list(data['index'])
|
||||
data['opt'] = [result[i]['opt'] for i in data['index']]
|
||||
data['log'] = [result[i]['log'] for i in data['index']]
|
||||
|
||||
# load split
|
||||
output_path = eval_file.replace(f'.{suffix}', f'_{name_str}_submission.tsv')
|
||||
dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_submission.tsv'))
|
||||
return output_path
|
||||
Reference in New Issue
Block a user