Mirror of https://github.com/OpenBMB/MiniCPM-V.git (synced 2026-02-05 18:29:18 +08:00)
Modify eval_mm for MiniCPM-V 2.6
eval_mm/vlmevalkit/vlmeval/dataset/text_mcq.py (new file, 123 lines)
@@ -0,0 +1,123 @@
from .text_base import TextBaseDataset
from .utils import build_judge, DEBUG_MESSAGE
from ..smp import *


class TextMCQDataset(TextBaseDataset):
    TYPE = 'MCQ'

    DATASET_URL = {}

    DATASET_MD5 = {}

    def build_prompt(self, line):

        if isinstance(line, int):
            line = self.data.iloc[line]

        question = line['question']
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        options_prompt = 'Options:\n'
        for key, item in options.items():
            options_prompt += f'{key}. {item}\n'
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        prompt = ''
        if hint is not None:
            prompt += f'Hint: {hint}\n'
        prompt += f'Question: {question}\n'
        if len(options):
            prompt += options_prompt
            prompt += 'Please select the correct answer from the options above. \n'

        msgs = []

        msgs.append(dict(type='text', value=prompt))

        return msgs

    def evaluate(self, eval_file, **judge_kwargs):
        from .utils.multiple_choice import report_acc, report_acc_MMT, mcq_circular_eval, mcq_vanilla_eval
        # assert dataset is not None
        dataset_map = {
            'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_EN_V11': 'MMBench_V11',
            'MMBench_TEST_CN': 'MMBench_CN', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11'
        }
        dataset = self.dataset_name
        if dataset in dataset_map:
            dataset = dataset_map[dataset]
        nproc = judge_kwargs.pop('nproc', 4)

        circular = False

        suffix = eval_file.split('.')[-1]
        model = judge_kwargs.get('model', 'exact_matching')
        assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
        name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
        name_str = name_str_map[model] if model in name_str_map else model

        if model == 'exact_matching':
            model = None
        elif gpt_key_set():
            model = build_judge(**judge_kwargs)
            if not model.working():
                warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
                warnings.warn(DEBUG_MESSAGE)
                model = None
        else:
            warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
            model = None

        result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')

        data = load(eval_file)
        data = data.sort_values(by='index')
        data['prediction'] = [str(x) for x in data['prediction']]
        # If not choice label, then use lower case
        for k in data.keys():
            data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)

        meta = self.data
        meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
        data_map = {x: y for x, y in zip(data['index'], data['question'])}
        for k in data_map:
            assert k in meta_q_map, (
                f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
            )

        if circular:
            data = mcq_circular_eval(model, data, meta, nproc, result_file, self.dataset_name)
        else:
            data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)

        # load split
        dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
        data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))

        # May have different report acc functions for different datasets
        if 'MMT' in dataset:
            acc = report_acc_MMT(data)
        else:
            acc = report_acc(data)

        score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
        dump(acc, score_file)

        return acc


class CustomTextMCQDataset(TextMCQDataset):

    def load_data(self, dataset):
        data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')

        if file_size(data_path, 'GB') > 1:
            local_path = data_path.replace('.tsv', '_local.tsv')
            if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None):
                from ..tools import LOCALIZE
                LOCALIZE(data_path, local_path)
            data_path = local_path
        return load(data_path)
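For reference, and not part of the committed file: a minimal sketch of how the prompt construction above could be exercised on a hand-made row. The example row, its column values, and the class-level call with None as self are illustrative assumptions, and it assumes the vlmevalkit package under eval_mm/vlmevalkit is importable as vlmeval.

# Minimal usage sketch (assumption: eval_mm/vlmevalkit is on PYTHONPATH so the
# package imports as vlmeval); the row below is fabricated for illustration.
import pandas as pd
from vlmeval.dataset.text_mcq import TextMCQDataset

row = pd.Series({
    'question': 'Which planet is known as the Red Planet?',
    'A': 'Venus', 'B': 'Mars', 'C': 'Jupiter',
    'hint': 'It is the fourth planet from the Sun.',
})

# build_prompt() only touches self when given an integer row index, so it can
# be called through the class here without constructing a dataset instance.
msgs = TextMCQDataset.build_prompt(None, row)
print(msgs[0]['value'])
# Prints the 'Hint:' line, the 'Question:' line, the lettered options, and the
# instruction to select an answer.

In normal use the dataset would be constructed by name (CustomTextMCQDataset loads the TSV at LMUDataRoot()/<dataset>.tsv) and build_prompt would be given a row index into that table.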