mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-05 18:29:18 +08:00
124 lines
4.4 KiB
Python
124 lines
4.4 KiB
Python
from .text_base import TextBaseDataset
|
|
from .utils import build_judge, DEBUG_MESSAGE
|
|
from ..smp import *
|
|
|
|
|
|
class TextMCQDataset(TextBaseDataset):
    """Text-only multiple-choice question (MCQ) dataset.

    Provides prompt construction for a single question record and an
    evaluation routine that scores a prediction file, either by exact
    matching or with a judge model built through ``build_judge``.
    """

    TYPE = 'MCQ'

    # Both registries are empty in this class.
    DATASET_URL = {}

    DATASET_MD5 = {}

    def build_prompt(self, line):
        """Build the text prompt for one question.

        Args:
            line: Either an ``int`` position into ``self.data`` (resolved with
                ``iloc``) or a record supporting ``in`` and key lookup
                (presumably a pandas Series -- confirm against callers). It
                must provide ``'question'``; ``'hint'`` and the single
                upper-case letter option columns are optional.

        Returns:
            A list holding one ``dict(type='text', value=<prompt>)`` message.
        """
        if isinstance(line, int):
            line = self.data.iloc[line]

        question = line['question']
        # Collect only the option letters that are present and not NaN.
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None

        parts = []
        if hint is not None:
            parts.append(f'Hint: {hint}\n')
        parts.append(f'Question: {question}\n')
        if options:
            parts.append('Options:\n')
            parts.extend(f'{key}. {item}\n' for key, item in options.items())
            parts.append('Please select the correct answer from the options above. \n')

        return [dict(type='text', value=''.join(parts))]

    def evaluate(self, eval_file, **judge_kwargs):
        """Score the predictions stored in ``eval_file``.

        Args:
            eval_file: Path of the prediction file. Output paths are derived
                from it by inserting a tag before its final extension.
            **judge_kwargs: ``nproc`` (default 4) is removed and passed to the
                evaluator; ``model`` (default ``'exact_matching'``) selects the
                judge and must be one of ``'chatgpt-0125'``,
                ``'exact_matching'`` or ``'gpt-4-0125'``. The remaining
                entries (including ``model``) are forwarded to ``build_judge``.

        Returns:
            The accuracy report produced by ``report_acc`` (or
            ``report_acc_MMT`` when the dataset name contains ``'MMT'``).

        Raises:
            AssertionError: If ``model`` is not a supported judge name, or if
                ``eval_file`` contains an index missing from ``self.data``.

        Side effects:
            Writes ``<base>_<judge>_result<ext>`` and ``<base>_acc.csv`` next
            to ``eval_file``; ``<base>_<judge>_result.pkl`` is handed to
            ``mcq_vanilla_eval`` as its result-file path.
        """
        from .utils.multiple_choice import report_acc, report_acc_MMT, mcq_vanilla_eval

        dataset_map = {
            'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_EN_V11': 'MMBench_V11',
            'MMBench_TEST_CN': 'MMBench_CN', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11'
        }
        dataset = dataset_map.get(self.dataset_name, self.dataset_name)
        nproc = judge_kwargs.pop('nproc', 4)

        model = judge_kwargs.get('model', 'exact_matching')
        assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125'], (
            f'unsupported judge model: {model!r}'
        )
        name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
        name_str = name_str_map.get(model, model)

        # Resolve the judge; any failure falls back to exact matching (None).
        if model == 'exact_matching':
            model = None
        elif gpt_key_set():
            model = build_judge(**judge_kwargs)
            if not model.working():
                warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
                warnings.warn(DEBUG_MESSAGE)
                model = None
        else:
            warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
            model = None

        # Split off only the trailing extension. A plain str.replace on
        # '.<suffix>' would also rewrite matching text earlier in the path
        # (e.g. a directory named 'run.xlsx_dump').
        base, ext = osp.splitext(eval_file)
        result_file = f'{base}_{name_str}_result.pkl'
        detail_file = f'{base}_{name_str}_result{ext}'
        score_file = f'{base}_acc.csv'

        data = load(eval_file)
        data = data.sort_values(by='index')
        data['prediction'] = [str(x) for x in data['prediction']]

        # Lower-case every column name except the single-letter choice labels.
        # Iterate over a snapshot because the frame is modified in the loop;
        # each column is popped and re-inserted so the final order matches the
        # original order.
        choice_labels = set(string.ascii_uppercase)
        for col in list(data.columns):
            data[col if col in choice_labels else col.lower()] = data.pop(col)

        meta = self.data
        known_indices = set(meta['index'])
        for idx in data['index']:
            assert idx in known_indices, (
                f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
            )

        # This class always uses the vanilla (non-circular) evaluation.
        data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)

        # Persist the detailed results, then reload them from disk before
        # computing accuracy.
        dump(data, detail_file)
        data = load(detail_file)

        # May have different report acc functions for different datasets
        if 'MMT' in dataset:
            acc = report_acc_MMT(data)
        else:
            acc = report_acc(data)

        dump(acc, score_file)

        return acc
|
|
|
|
|
|
class CustomTextMCQDataset(TextMCQDataset):
    """Text MCQ dataset read from ``<dataset>.tsv`` under ``LMUDataRoot()``."""

    def load_data(self, dataset):
        """Load the TSV file for ``dataset``.

        When ``file_size(path, 'GB')`` exceeds 1, a ``*_local.tsv`` copy is
        produced with ``LOCALIZE`` (if it does not exist yet, or if the
        ``FORCE_LOCAL`` environment variable is set to a non-empty value) and
        that copy is loaded instead.

        Args:
            dataset: Dataset name; the file ``<dataset>.tsv`` is looked up in
                the directory returned by ``LMUDataRoot()``.

        Returns:
            Whatever ``load`` returns for the selected path.
        """
        tsv_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')

        # Small files are loaded directly.
        if not file_size(tsv_path, 'GB') > 1:
            return load(tsv_path)

        localized_path = tsv_path.replace('.tsv', '_local.tsv')
        needs_localize = not osp.exists(localized_path) or os.environ.get('FORCE_LOCAL', None)
        if needs_localize:
            from ..tools import LOCALIZE

            LOCALIZE(tsv_path, localized_path)

        return load(localized_path)
|