Mirror of https://github.com/OpenBMB/MiniCPM-V.git (synced 2026-02-05 18:29:18 +08:00)

Commit: Add eval_mm dir
eval_mm/vlmevalkit/vlmeval/utils/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from .matching_util import can_infer, can_infer_option, can_infer_text
from .mp_util import track_progress_rich
from .custom_prompt import CustomPrompt
from .dataset_config import dataset_URLs, img_root_map, DATASET_TYPE, abbr2full
from .dataset import TSVDataset, split_MMMU, MMMU_result_transfer


__all__ = [
    'can_infer', 'can_infer_option', 'can_infer_text', 'track_progress_rich',
    'TSVDataset', 'dataset_URLs', 'img_root_map', 'DATASET_TYPE', 'CustomPrompt',
    'split_MMMU', 'abbr2full'
]

eval_mm/vlmevalkit/vlmeval/utils/custom_prompt.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from ..smp import *
from .dataset_config import img_root_map
from abc import abstractmethod


class CustomPrompt:

    @abstractmethod
    def use_custom_prompt(self, dataset):
        raise NotImplementedError

    @abstractmethod
    def build_prompt(self, line, dataset):
        raise NotImplementedError

    def dump_image(self, line, dataset):
        ROOT = LMUDataRoot()
        assert isinstance(dataset, str)
        img_root = osp.join(ROOT, 'images', img_root_map[dataset] if dataset in img_root_map else dataset)
        os.makedirs(img_root, exist_ok=True)
        if isinstance(line['image'], list):
            tgt_path = []
            assert 'image_path' in line
            for img, im_name in zip(line['image'], line['image_path']):
                path = osp.join(img_root, im_name)
                if not read_ok(path):
                    decode_base64_to_image_file(img, path)
                tgt_path.append(path)
        else:
            tgt_path = osp.join(img_root, f"{line['index']}.jpg")
            if not read_ok(tgt_path):
                decode_base64_to_image_file(line['image'], tgt_path)
        return tgt_path
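
A note on usage: CustomPrompt is the abstract mixin that model wrappers inherit when they supply dataset-specific prompts, with dump_image as the shared helper that materializes base64-encoded images to disk. A minimal subclass might look like the sketch below (the class name, prompt text, and single-image assumption are illustrative, not part of this commit):

from vlmeval.utils import CustomPrompt, DATASET_TYPE

class ExamplePromptModel(CustomPrompt):

    def use_custom_prompt(self, dataset):
        # Opt in to custom prompts for multi-choice datasets only (illustrative policy).
        return DATASET_TYPE(dataset) == 'multi-choice'

    def build_prompt(self, line, dataset):
        # dump_image (inherited) decodes line['image'] to a local file and
        # returns its path (or a list of paths for multi-image samples).
        tgt_path = self.dump_image(line, dataset)
        prompt = line['question'] + '\nAnswer with the option letter directly.'
        return [dict(type='image', value=tgt_path),
                dict(type='text', value=prompt)]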

eval_mm/vlmevalkit/vlmeval/utils/dataset.py (new file, 190 lines)
@@ -0,0 +1,190 @@
import pandas as pd
import hashlib
from ..smp import *
from .dataset_config import dataset_URLs, dataset_md5_dict, DATASET_TYPE
from .custom_prompt import CustomPrompt
from .matching_util import can_infer


def isliststr(s):
    return (s[0] == '[') and (s[-1] == ']')


def check_md5(data_path, dataset):
    if dataset not in dataset_md5_dict:
        warnings.warn(f'We do not have an md5 record for dataset {dataset}, skip the md5 check. ')
        return True
    assert osp.exists(data_path)
    with open(data_path, 'rb') as f:
        hash = hashlib.new('md5')
        for chunk in iter(lambda: f.read(2**20), b''):
            hash.update(chunk)
        if str(hash.hexdigest()) == dataset_md5_dict[dataset]:
            return True
        else:
            warnings.warn('this data file is incomplete, so it needs to be downloaded again.')
            return False


def split_MMMU(msgs):
    text, images = None, []
    for s in msgs:
        if s['type'] == 'image':
            images.append(s['value'])
        elif s['type'] == 'text':
            assert text is None
            text = s['value']
    text_segs = text.split('<image ')
    segs = [dict(type='text', value=text_segs[0])]
    for i, seg in enumerate(text_segs):
        if i == 0:
            continue
        assert istype(seg[0], int) and seg[1] == '>'
        image_idx = int(seg[0]) - 1
        segs.append(dict(type='image', value=images[image_idx]))
        segs.append(dict(type='text', value=seg[2:]))
    return segs


def MMMU_result_transfer(result_path):
    res = {}
    result_data = load(result_path)
    mcq = result_data['A'].notna()
    lt = len(result_data)
    for i in range(lt):
        line = result_data.iloc[i]
        if mcq[i]:
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            prediction = line['prediction']
            infer_prediction = can_infer(prediction, options)
            res[line['id']] = infer_prediction
        else:
            res[line['id']] = line['prediction']
    result_json = result_path.replace('.xlsx', '.json')
    dump(res, result_json)
    return result_json


class TSVDataset(CustomPrompt):

    def __init__(self, dataset='MMBench', skip_noimg=True):

        self.data_root = LMUDataRoot()
        assert osp.exists(self.data_root)

        self.dataset = dataset
        self.dataset_type = DATASET_TYPE(dataset)

        if dataset in dataset_URLs:
            url = dataset_URLs[dataset]
            file_name = url.split('/')[-1]
            data_path = osp.join(self.data_root, file_name)

            if osp.exists(data_path) and check_md5(data_path, dataset):
                pass
            elif osp.isfile(url):
                # If url is actually a file path, use it directly
                data_path = url
            else:
                warnings.warn('The dataset tsv is not downloaded')
                download_file(url, data_path)
        else:
            data_path = osp.join(self.data_root, dataset + '.tsv')
            assert osp.exists(data_path)

        data = load(data_path)
        self.skip_noimg = skip_noimg
        if skip_noimg and 'image' in data:
            data = data[~pd.isna(data['image'])]

        # Prompt for Captioning
        if listinstr(['COCO'], dataset):
            data['question'] = [(
                'Please describe this image in general. Directly provide the description, '
                'do not include prefix like "This image depicts". '
            )] * len(data)

        data['index'] = [str(x) for x in data['index']]

        self.meta_only = True
        if 'image' in data:
            data['image'] = [str(x) for x in data['image']]

            image_map = {x: y for x, y in zip(data['index'], data['image'])}
            for k in image_map:
                if len(image_map[k]) <= 64:
                    idx = image_map[k]
                    assert idx in image_map and len(image_map[idx]) > 64
                    image_map[k] = image_map[idx]

            data['image'] = [
                eval(image_map[k]) if isliststr(image_map[k]) else image_map[k]
                for k in data['index']
            ]
            self.meta_only = False

        if 'image_path' in data:
            data['image_path'] = [
                eval(pths) if isliststr(pths) else pths for pths in data['image_path']
            ]

        if np.all([istype(x, int) for x in data['index']]):
            data['index'] = [int(x) for x in data['index']]

        self.data = data

    def __len__(self):
        return len(self.data)

    def build_prompt(self, line, dataset=None):
        if dataset is None:
            dataset = self.dataset

        if isinstance(line, int):
            line = self.data.iloc[line]

        if self.meta_only:
            tgt_path = line['image_path']
        else:
            tgt_path = self.dump_image(line, dataset)

        prompt = line['question']
        if DATASET_TYPE(dataset) == 'multi-choice':
            question = line['question']
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = 'Options:\n'
            for key, item in options.items():
                options_prompt += f'{key}. {item}\n'
            hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
            prompt = ''
            if hint is not None:
                prompt += f'Hint: {hint}\n'
            prompt += f'Question: {question}\n'
            if len(options):
                prompt += options_prompt
                prompt += 'Please select the correct answer from the options above. \n'
        elif DATASET_TYPE(dataset) == 'VQA':
            if listinstr(['ocrvqa', 'textvqa', 'chartqa', 'docvqa'], dataset.lower()):
                prompt += '\nPlease try to answer the question with short words or phrases if possible\n.'

        msgs = []
        if isinstance(tgt_path, list):
            msgs.extend([dict(type='image', value=p) for p in tgt_path])
        else:
            msgs = [dict(type='image', value=tgt_path)]
        msgs.append(dict(type='text', value=prompt))

        return msgs

    def display(self, line):
        if isinstance(line, int):
            line = self.data.iloc[line]
        mmqa_display(line)
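
A note on usage: TSVDataset ties these pieces together. It resolves the TSV via dataset_URLs, verifies it with check_md5, downloads it if needed, and turns each row into the interleaved image/text message list that models consume. A minimal sketch, assuming an LMUData root is configured and the TSV is reachable:

from vlmeval.utils import TSVDataset

data = TSVDataset('MMBench_DEV_EN')  # downloads and md5-checks the TSV on first use
print(len(data))                     # rows with images (skip_noimg=True by default)
msgs = data.build_prompt(0)          # [{'type': 'image', ...}, {'type': 'text', ...}]
print([m['type'] for m in msgs])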

eval_mm/vlmevalkit/vlmeval/utils/dataset_config.py (new file, 158 lines)
@@ -0,0 +1,158 @@
from ..smp import listinstr

dataset_URLs = {
    # MMBench v1.0
    'MMBench_DEV_EN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN.tsv',
    'MMBench_TEST_EN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN.tsv',
    'MMBench_DEV_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN.tsv',
    'MMBench_TEST_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN.tsv',
    'MMBench': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench.tsv',  # Internal Only
    'MMBench_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN.tsv',  # Internal Only
    # MMBench v1.1
    'MMBench_DEV_EN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN_V11.tsv',
    'MMBench_TEST_EN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN_V11.tsv',
    'MMBench_DEV_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN_V11.tsv',
    'MMBench_TEST_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN_V11.tsv',
    'MMBench_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_V11.tsv',  # Internal Only
    'MMBench_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN_V11.tsv',  # Internal Only
    # CCBench
    'CCBench': 'https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv',
    'MME': 'https://opencompass.openxlab.space/utils/VLMEval/MME.tsv',
    'SEEDBench_IMG': 'https://opencompass.openxlab.space/utils/VLMEval/SEEDBench_IMG.tsv',
    'CORE_MM': 'https://opencompass.openxlab.space/utils/VLMEval/CORE_MM.tsv',
    'MMVet': 'https://opencompass.openxlab.space/utils/VLMEval/MMVet.tsv',
    'COCO_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/COCO_VAL.tsv',
    'OCRVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TEST.tsv',
    'OCRVQA_TESTCORE': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TESTCORE.tsv',
    'TextVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/TextVQA_VAL.tsv',
    'MMMU_DEV_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv',
    'MMMU_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv',
    'MathVista_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv',
    'ScienceQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_VAL.tsv',
    'ScienceQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_TEST.tsv',
    'HallusionBench': 'https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv',
    'DocVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/DocVQA_VAL.tsv',
    'DocVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/DocVQA_TEST.tsv',
    'InfoVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/InfoVQA_VAL.tsv',
    'InfoVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/InfoVQA_TEST.tsv',
    'AI2D_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv',
    'LLaVABench': 'https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv',
    'OCRBench': 'https://opencompass.openxlab.space/utils/VLMEval/OCRBench.tsv',
    'ChartQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/ChartQA_TEST.tsv',
    'MMStar': 'https://opencompass.openxlab.space/utils/VLMEval/MMStar.tsv',
    'RealWorldQA': 'https://opencompass.openxlab.space/utils/VLMEval/RealWorldQA.tsv',
    'POPE': 'https://opencompass.openxlab.space/utils/VLMEval/POPE.tsv',
}

dataset_md5_dict = {
    # MMBench v1.0
    'MMBench_DEV_EN': 'b6caf1133a01c6bb705cf753bb527ed8',
    'MMBench_TEST_EN': '6939fadb0ce626fefc0bdc9c64efc528',
    'MMBench_DEV_CN': '08b8fc3324a5ed74155350f57be69fbd',
    'MMBench_TEST_CN': '7e1239baf0ee4c8b513e19705a0f317e',
    'MMBench': '4115aea3383f3dd0083be6a633e0f820',  # Internal Only
    'MMBench_CN': '2e053ffc90ea598b1feae13c36dc13ee',  # Internal Only
    # MMBench v1.1
    'MMBench_DEV_EN_V11': '30c05be8f2f347a50be25aa067248184',
    'MMBench_TEST_EN_V11': '26f0f15381a21720255091d3e0316ce6',
    'MMBench_DEV_CN_V11': '593f9b5f6bea453d870a798b34ae4f37',
    'MMBench_TEST_CN_V11': '74bbe4556dac745613c7cbe5ad787050',
    'MMBench_V11': 'b9276414f57af1308dcc4d0cd9b42e7c',  # Internal Only
    'MMBench_CN_V11': '95f6980dd1b4de38e3cbffe0305a3f25',  # Internal Only
    # CCBench
    'CCBench': '1de88b4257e7eee3f60b18d45eda6f07',
    'MME': 'b36b43c3f09801f5d368627fb92187c3',
    'SEEDBench_IMG': '68017231464752261a2526d6ca3a10c0',
    'CORE_MM': '8a8da2f2232e79caf98415bfdf0a202d',
    'MMVet': '748aa6d4aa9d4de798306a63718455e3',
    'COCO_VAL': '72a5079dead060269ac222c5aa5128af',
    'OCRVQA_TEST': 'ca46a6d74b403e9d6c0b670f6fc00db9',
    'OCRVQA_TESTCORE': 'c5239fe77db8bdc1f2ad8e55e0d1fe97',
    'TextVQA_VAL': 'b233b31f551bbf4056f2f955da3a92cd',
    'MMMU_DEV_VAL': '521afc0f3bf341e6654327792781644d',
    'MMMU_TEST': 'c19875d11a2d348d07e5eb4bdf33166d',
    'MathVista_MINI': 'f199b98e178e5a2a20e7048f5dcb0464',
    'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3',
    'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f',
    'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c',
    'DocVQA_VAL': 'd5ee77e1926ff10690d469c56b73eabf',
    'DocVQA_TEST': '6a2f28cac26ef2d3447374e8c6f6c8e9',
    'InfoVQA_VAL': '2342e9c225222f0ef4dec545ebb126fe',
    'InfoVQA_TEST': 'df535bf51b88dc9718252c34131a6227',
    'AI2D_TEST': '0f593e0d1c7df9a3d69bf1f947e71975',
    'LLaVABench': 'd382a093f749a697820d3dadd61c8428',
    'OCRBench': 'e953d98a987cc6e26ef717b61260b778',
    'ChartQA_TEST': 'c902e0aa9be5582a7aad6dcf52734b42',
    'MMStar': 'e1ecd2140806c1b1bbf54b43372efb9e',
    'RealWorldQA': '92321028d2bc29040284b6674721e48f',
    'POPE': 'c12f5acb142f2ef1f85a26ba2fbe41d5',
}

img_root_map = {k: k for k in dataset_URLs}
img_root_map.update({
    # MMBench v1.0
    'MMBench_DEV_EN': 'MMBench',
    'MMBench_TEST_EN': 'MMBench',
    'MMBench_DEV_CN': 'MMBench',
    'MMBench_TEST_CN': 'MMBench',
    'MMBench': 'MMBench',  # Internal Only
    'MMBench_CN': 'MMBench',  # Internal Only
    # MMBench v1.1
    'MMBench_DEV_EN_V11': 'MMBench_V11',
    'MMBench_TEST_EN_V11': 'MMBench_V11',
    'MMBench_DEV_CN_V11': 'MMBench_V11',
    'MMBench_TEST_CN_V11': 'MMBench_V11',
    'MMBench_V11': 'MMBench_V11',  # Internal Only
    'MMBench_CN_V11': 'MMBench_V11',  # Internal Only
    'COCO_VAL': 'COCO',
    'OCRVQA_TEST': 'OCRVQA',
    'OCRVQA_TESTCORE': 'OCRVQA',
    'TextVQA_VAL': 'TextVQA',
    'MMMU_DEV_VAL': 'MMMU',
    'MMMU_TEST': 'MMMU',
    'MathVista_MINI': 'MathVista',
    'HallusionBench': 'Hallusion',
    'DocVQA_VAL': 'DocVQA',
    'DocVQA_TEST': 'DocVQA_TEST',
    'OCRBench': 'OCRBench',
    'ChartQA_TEST': 'ChartQA_TEST',
    'InfoVQA_VAL': 'InfoVQA_VAL',
    'InfoVQA_TEST': 'InfoVQA_TEST',
    'MMStar': 'MMStar',
    'RealWorldQA': 'RealWorldQA',
    'POPE': 'POPE',
})

assert set(dataset_URLs) == set(img_root_map)


def DATASET_TYPE(dataset):
    # Dealing with Custom Dataset
    dataset = dataset.lower()
    if listinstr(['mmbench', 'seedbench', 'ccbench', 'mmmu', 'scienceqa', 'ai2d', 'mmstar', 'realworldqa'], dataset):
        return 'multi-choice'
    elif listinstr(['mme', 'hallusion', 'pope'], dataset):
        return 'Y/N'
    elif 'coco' in dataset:
        return 'Caption'
    elif listinstr(['ocrvqa', 'textvqa', 'chartqa', 'mathvista', 'docvqa', 'infovqa', 'llavabench',
                    'mmvet', 'ocrbench'], dataset):
        return 'VQA'
    else:
        if dataset not in dataset_URLs:
            import warnings
            warnings.warn(f"Dataset {dataset} not found in dataset_URLs, will use 'multi-choice' as the default TYPE.")
            return 'multi-choice'
        else:
            return 'QA'


def abbr2full(s):
    datasets = [x for x in img_root_map]
    ins = [s in d for d in datasets]
    if sum(ins) == 1:
        for d in datasets:
            if s in d:
                return d
    else:
        return s
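
A note on behavior: DATASET_TYPE matches lowercase substrings, so every split of a benchmark maps to the same type, and abbr2full expands a partial name only when it matches exactly one entry in img_root_map. For example (return values follow from the tables above):

from vlmeval.utils import DATASET_TYPE, abbr2full

DATASET_TYPE('MMBench_DEV_EN_V11')  # 'multi-choice' ('mmbench' is a substring)
DATASET_TYPE('COCO_VAL')            # 'Caption'
DATASET_TYPE('POPE')                # 'Y/N'
abbr2full('RealWorldQA')            # 'RealWorldQA' (unique match)
abbr2full('MMBench_DEV')            # ambiguous (several matches), returned unchanged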

eval_mm/vlmevalkit/vlmeval/utils/matching_util.py (new file, 69 lines)
@@ -0,0 +1,69 @@
import string
import copy as cp
import os
from ..smp import *


def can_infer_option(answer, choices):
    verbose = os.environ.get('VERBOSE', 0)
    # Choices is a dictionary
    if 'Failed to obtain answer via API' in answer:
        return False

    reject_to_answer = [
        "Sorry, I can't help with images of people yet.",
        "I can't process this file.",
        "I'm sorry, but without the image provided",
        'Cannot determine the answer'
    ]
    for err in reject_to_answer:
        if err in answer:
            return 'Z'

    def count_choice(splits, choices, prefix='', suffix=''):
        cnt = 0
        for c in choices:
            if prefix + c + suffix in splits:
                cnt += 1
        return cnt

    answer_mod = cp.copy(answer)
    chars = '.()[],:;!*#{}'
    for c in chars:
        answer_mod = answer_mod.replace(c, ' ')

    splits = [x.strip() for x in answer_mod.split()]
    count = count_choice(splits, choices)

    if count == 1:
        for ch in choices:
            if 'A' in splits and len(splits) > 3 and verbose:
                logger = get_logger('Evaluation')
                logger.info(f'A might be a quantifier in the string: {answer}.')
                return False
            if ch in splits:
                return ch
    elif count == 0 and count_choice(splits, {'Z', ''}) == 1:
        return 'Z'
    return False


def can_infer_text(answer, choices):
    answer = answer.lower()
    assert isinstance(choices, dict)
    for k in choices:
        assert k in string.ascii_uppercase
        choices[k] = str(choices[k]).lower()
    cands = []
    for k in choices:
        if choices[k] in answer:
            cands.append(k)
    if len(cands) == 1:
        return cands[0]
    return False


def can_infer(answer, choices):
    answer = str(answer)
    copt = can_infer_option(answer, choices)
    return copt if copt else can_infer_text(answer, choices)
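
A note on behavior: can_infer first tries the strict option-letter match and falls back to fuzzy option-text matching only when that fails. A quick illustration with made-up values:

from vlmeval.utils import can_infer

choices = {'A': 'cat', 'B': 'dog', 'C': 'bird'}
can_infer('The answer is (B).', choices)           # 'B': the letter survives punctuation stripping
can_infer('It looks like a bird to me.', choices)  # 'C': matched via the option text
can_infer('I cannot tell.', choices)               # False: nothing can be inferred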

eval_mm/vlmevalkit/vlmeval/utils/mp_util.py (new file, 191 lines)
@@ -0,0 +1,191 @@
from multiprocessing import Pool
import os
from typing import Callable, Iterable, Sized

from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
                           TaskProgressColumn, TextColumn, TimeRemainingColumn)
from rich.text import Text
import os.path as osp
import portalocker
from ..smp import load, dump


class _Worker:
    """Function wrapper for ``track_progress_rich``"""

    def __init__(self, func) -> None:
        self.func = func

    def __call__(self, inputs):
        inputs, idx = inputs
        if not isinstance(inputs, (tuple, list, dict)):
            inputs = (inputs, )

        if isinstance(inputs, dict):
            return self.func(**inputs), idx
        else:
            return self.func(*inputs), idx


class _SkipFirstTimeRemainingColumn(TimeRemainingColumn):
    """Skip calculating remaining time for the first few times.

    Args:
        skip_times (int): The number of times to skip. Defaults to 0.
    """

    def __init__(self, *args, skip_times=0, **kwargs):
        super().__init__(*args, **kwargs)
        self.skip_times = skip_times

    def render(self, task: Task) -> Text:
        """Show time remaining."""
        if task.completed <= self.skip_times:
            return Text('-:--:--', style='progress.remaining')
        return super().render(task)


def _tasks_with_index(tasks):
    """Add index to tasks."""
    for idx, task in enumerate(tasks):
        yield task, idx


def track_progress_rich(func: Callable,
                        tasks: Iterable = tuple(),
                        task_num: int = None,
                        nproc: int = 1,
                        chunksize: int = 1,
                        description: str = 'Processing',
                        save=None, keys=None,
                        color: str = 'blue') -> list:
    """Track the progress of parallel task execution with a progress bar. The
    built-in :mod:`multiprocessing` module is used for process pools and tasks
    are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.

    Args:
        func (callable): The function to be applied to each task.
        tasks (Iterable or Sized): A tuple of tasks. There are several cases
            for different format tasks:
            - When ``func`` accepts no arguments: tasks should be an empty
              tuple, and ``task_num`` must be specified.
            - When ``func`` accepts only one argument: tasks should be a tuple
              containing the argument.
            - When ``func`` accepts multiple arguments: tasks should be a
              tuple, with each element representing a set of arguments.
              If an element is a ``dict``, it will be parsed as a set of
              keyword-only arguments.
            Defaults to an empty tuple.
        task_num (int, optional): If ``tasks`` is an iterator which does not
            have length, the number of tasks can be provided by ``task_num``.
            Defaults to None.
        nproc (int): Process (worker) number. If nproc is 1,
            use single process. Defaults to 1.
        chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
            Defaults to 1.
        description (str): The description of progress bar.
            Defaults to "Processing".
        color (str): The color of progress bar. Defaults to "blue".

    Examples:
        >>> import time

        >>> def func(x):
        ...    time.sleep(1)
        ...    return x**2
        >>> track_progress_rich(func, range(10), nproc=2)

    Returns:
        list: The task results.
    """
    if save is not None:
        assert osp.exists(osp.dirname(save)) or osp.dirname(save) == ''
        if not osp.exists(save):
            dump({}, save)
    if keys is not None:
        assert len(keys) == len(tasks)

    if not callable(func):
        raise TypeError('func must be a callable object')
    if not isinstance(tasks, Iterable):
        raise TypeError(
            f'tasks must be an iterable object, but got {type(tasks)}')
    if isinstance(tasks, Sized):
        if len(tasks) == 0:
            if task_num is None:
                raise ValueError('If tasks is an empty iterable, '
                                 'task_num must be set')
            else:
                tasks = tuple(tuple() for _ in range(task_num))
        else:
            if task_num is not None and task_num != len(tasks):
                raise ValueError('task_num does not match the length of tasks')
            task_num = len(tasks)

    if nproc <= 0:
        raise ValueError('nproc must be a positive number')

    skip_times = nproc * chunksize if nproc > 1 else 0
    prog_bar = Progress(
        TextColumn('{task.description}'),
        BarColumn(),
        _SkipFirstTimeRemainingColumn(skip_times=skip_times),
        MofNCompleteColumn(),
        TaskProgressColumn(show_speed=True),
    )

    worker = _Worker(func)
    task_id = prog_bar.add_task(
        total=task_num, color=color, description=description)
    tasks = _tasks_with_index(tasks)

    # Use single process when nproc is 1, else use multiprocess.
    with prog_bar:
        if nproc == 1:
            results = []
            for task in tasks:
                result, idx = worker(task)
                results.append(result)
                if save is not None:
                    with portalocker.Lock(save, timeout=5) as fh:
                        ans = load(save)
                        ans[keys[idx]] = result

                        if os.environ.get('VERBOSE', True):
                            print(keys[idx], result, flush=True)

                        dump(ans, save)
                        fh.flush()
                        os.fsync(fh.fileno())

                prog_bar.update(task_id, advance=1, refresh=True)
        else:
            with Pool(nproc) as pool:
                results = []
                unordered_results = []
                gen = pool.imap_unordered(worker, tasks, chunksize)
                try:
                    for result in gen:
                        result, idx = result
                        unordered_results.append((result, idx))

                        if save is not None:
                            with portalocker.Lock(save, timeout=5) as fh:
                                ans = load(save)
                                ans[keys[idx]] = result

                                if os.environ.get('VERBOSE', False):
                                    print(keys[idx], result, flush=True)

                                dump(ans, save)
                                fh.flush()
                                os.fsync(fh.fileno())

                        results.append(None)
                        prog_bar.update(task_id, advance=1, refresh=True)
                except Exception as e:
                    prog_bar.stop()
                    raise e
                for result, idx in unordered_results:
                    results[idx] = result
    return results
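
A note on usage: when save and keys are given, each finished task's result is written incrementally to a JSON file under its key, guarded by a portalocker file lock, so a long evaluation run can be resumed after interruption. A minimal sketch (the file name and worker are illustrative; as with any multiprocessing code, the worker must be defined at module level so it can be pickled):

from vlmeval.utils import track_progress_rich

def square(x):
    return x * x

results = track_progress_rich(
    square,
    tasks=[1, 2, 3, 4],
    keys=['t1', 't2', 't3', 't4'],  # one key per task, used in the save file
    save='squares.json',            # partial results accumulate here
    nproc=2,
)
print(results)  # [1, 4, 9, 16], re-ordered by task index despite imap_unordered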