# Copyright (c) OpenMMLab. All rights reserved.
# Partly adopted from https://github.com/GT-Vision-Lab/VQA
# Copyright (c) 2014, Aishwarya Agrawal

import re

from vlmeval.smp import *
from typing import Optional
from functools import partial


def _process_digit_article(inText):
    outText = []
    tempText = inText.lower().split()
    articles = ['a', 'an', 'the']
    manualMap = {
        'none': '0',
        'zero': '0',
        'one': '1',
        'two': '2',
        'three': '3',
        'four': '4',
        'five': '5',
        'six': '6',
        'seven': '7',
        'eight': '8',
        'nine': '9',
        'ten': '10',
    }
    contractions = {
        'aint': "ain't",
        'arent': "aren't",
        'cant': "can't",
        'couldve': "could've",
        'couldnt': "couldn't",
        "couldn'tve": "couldn't've",
        "couldnt've": "couldn't've",
        'didnt': "didn't",
        'doesnt': "doesn't",
        'dont': "don't",
        'hadnt': "hadn't",
        "hadnt've": "hadn't've",
        "hadn'tve": "hadn't've",
        'hasnt': "hasn't",
        'havent': "haven't",
        'hed': "he'd",
        "hed've": "he'd've",
        "he'dve": "he'd've",
        'hes': "he's",
        'howd': "how'd",
        'howll': "how'll",
        'hows': "how's",
        "Id've": "I'd've",
        "I'dve": "I'd've",
        'Im': "I'm",
        'Ive': "I've",
        'isnt': "isn't",
        'itd': "it'd",
        "itd've": "it'd've",
        "it'dve": "it'd've",
        'itll': "it'll",
        "let's": "let's",
        'maam': "ma'am",
        'mightnt': "mightn't",
        "mightnt've": "mightn't've",
        "mightn'tve": "mightn't've",
        'mightve': "might've",
        'mustnt': "mustn't",
        'mustve': "must've",
        'neednt': "needn't",
        'notve': "not've",
        'oclock': "o'clock",
        'oughtnt': "oughtn't",
        "ow's'at": "'ow's'at",
        "'ows'at": "'ow's'at",
        "'ow'sat": "'ow's'at",
        'shant': "shan't",
        "shed've": "she'd've",
        "she'dve": "she'd've",
        "she's": "she's",
        'shouldve': "should've",
        'shouldnt': "shouldn't",
        "shouldnt've": "shouldn't've",
        "shouldn'tve": "shouldn't've",
        'somebodyd': "somebody'd",
        "somebodyd've": "somebody'd've",
        "somebody'dve": "somebody'd've",
        'somebodyll': "somebody'll",
        'somebodys': "somebody's",
        'someoned': "someone'd",
        "someoned've": "someone'd've",
        "someone'dve": "someone'd've",
        'someonell': "someone'll",
        'someones': "someone's",
        'somethingd': "something'd",
        "somethingd've": "something'd've",
        "something'dve": "something'd've",
        'somethingll': "something'll",
        'thats': "that's",
        'thered': "there'd",
        "thered've": "there'd've",
        "there'dve": "there'd've",
        'therere': "there're",
        'theres': "there's",
        'theyd': "they'd",
        "theyd've": "they'd've",
        "they'dve": "they'd've",
        'theyll': "they'll",
        'theyre': "they're",
        'theyve': "they've",
        'twas': "'twas",
        'wasnt': "wasn't",
        "wed've": "we'd've",
        "we'dve": "we'd've",
        'weve': "we've",
        'werent': "weren't",
        'whatll': "what'll",
        'whatre': "what're",
        'whats': "what's",
        'whatve': "what've",
        'whens': "when's",
        'whered': "where'd",
        'wheres': "where's",
        'whereve': "where've",
        'whod': "who'd",
        "whod've": "who'd've",
        "who'dve": "who'd've",
        'wholl': "who'll",
        'whos': "who's",
        'whove': "who've",
        'whyll': "why'll",
        'whyre': "why're",
        'whys': "why's",
        'wont': "won't",
        'wouldve': "would've",
        'wouldnt': "wouldn't",
        "wouldnt've": "wouldn't've",
        "wouldn'tve": "wouldn't've",
        'yall': "y'all",
        "yall'll": "y'all'll",
        "y'allll": "y'all'll",
        "yall'd've": "y'all'd've",
        "y'alld've": "y'all'd've",
        "y'all'dve": "y'all'd've",
        'youd': "you'd",
        "youd've": "you'd've",
        "you'dve": "you'd've",
        'youll': "you'll",
        'youre': "you're",
        'youve': "you've",
    }
    for word in tempText:
        # Map number words to digits and drop articles.
        word = manualMap.get(word, word)
        if word not in articles:
            outText.append(word)
    # Restore apostrophes in common contractions.
    for wordId, word in enumerate(outText):
        if word in contractions:
            outText[wordId] = contractions[word]
    outText = ' '.join(outText)
    return outText

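# e.g. _process_digit_article('There are three apples on the table')
#      == 'there are 3 apples on table'

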
def hit_calculate(result, dataset_name, anls_threshold=0.5):
    if listinstr(['TextVQA'], dataset_name):
        return [np.mean(x['match']) for x in result]
    elif listinstr(['DocVQA', 'InfoVQA'], dataset_name):
        # ANLS: scores below the threshold are zeroed out, otherwise 1 - NLD is kept.
        return [0.0 if 1 - np.min(x['match']) < anls_threshold else 1 - np.min(x['match']) for x in result]
    elif listinstr(['ChartQA', 'OCRVQA'], dataset_name):
        return [np.max(x['match']) for x in result]
    else:  # default: use vqa_score to calculate the score
        return [np.mean(x['match']) for x in result]

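# Note on hit_calculate: the meaning of x['match'] depends on the scoring method
# used in process_line (defined below) -- per-reference VQA accuracies for
# vqa_score (averaged), normalized edit distances for anls (converted via
# 1 - min(...) and thresholded), and 0/1 flags for relaxed_accuracy / accuracy.

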
# https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81
def relaxed_correctness(target: str,
                        prediction: str,
                        max_relative_change: float = 0.05) -> bool:
    """Calculates relaxed correctness.

    The correctness tolerates certain error ratio defined by max_relative_change.
    See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
    “Following Methani et al. (2020), we use a relaxed accuracy measure for the
    numeric answers to allow a minor inaccuracy that may result from the automatic
    data extraction process. We consider an answer to be correct if it is within
    5% of the gold answer. For non-numeric answers, we still need an exact match
    to consider an answer to be correct.”

    Args:
        target: Target string.
        prediction: Predicted string.
        max_relative_change: Maximum relative change.

    Returns:
        Whether the prediction was correct given the specified tolerance.
    """

    def _to_float(text: str) -> Optional[float]:
        try:
            if text.endswith('%'):
                # Convert percentages to floats.
                return float(text.rstrip('%')) / 100.0
            else:
                return float(text)
        except ValueError:
            return None

    prediction = str(prediction)
    target = str(target)
    prediction_float = _to_float(prediction)
    target_float = _to_float(target)
    if prediction_float is not None and target_float:
        relative_change = abs(prediction_float - target_float) / abs(target_float)
        return relative_change <= max_relative_change
    else:
        return prediction.lower() == target.lower()

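# e.g. relaxed_correctness('0.30', '30%') is True (both parse to 0.30, relative
#      change 0 <= 0.05), while relaxed_correctness('100', '94') is False (6% off).

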
def levenshtein_distance(s1, s2):
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]

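# e.g. levenshtein_distance('kitten', 'sitting') == 3

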
def anls_compute(groundtruth, prediction):
    gt_answer = ' '.join(groundtruth.strip().lower().split())
    det_answer = ' '.join(prediction.strip().lower().split())
    dist = levenshtein_distance(gt_answer, det_answer)
    length = max(len(groundtruth.upper()), len(prediction.upper()))
    values = 0.0 if length == 0 else float(dist) / float(length)
    return values

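# Note: anls_compute returns the normalized edit distance (0.0 means an exact
# match); hit_calculate turns it into an ANLS-style similarity via 1 - min(...).
# e.g. anls_compute('forty two', 'fourty two') == 0.1 (distance 1 over length 10)

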
def process_answer(answer):
    answer = answer.replace('\n', ' ')
    answer = answer.replace('\t', ' ')
    answer = answer.strip()
    answer = process_punctuation(answer)
    answer = _process_digit_article(answer)
    return answer

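# e.g. process_answer('The two dogs!') should yield '2 dogs' (punctuation handled
#      by process_punctuation from vlmeval.smp, articles and number words above).

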
def process_line(line, method='vqa_score'):
    ret = {}
    if istype(line['answer'], list):
        answers = eval(line['answer'])
    else:
        answers = [line['answer']]
    if method == 'vqa_score':
        ret['gt'] = [process_answer(x) for x in answers]
        ret['pred'] = process_answer(line['prediction'])
        ret['match'] = []
        for current_idx, gtAnsDatum in enumerate(ret['gt']):
            otherGTAns = [
                item for ret_gt_idx, item in enumerate(ret['gt'])
                if ret_gt_idx != current_idx
            ]
            matchingAns = [
                item for item in otherGTAns if item == ret['pred']
            ]
            # VQA accuracy: full credit once the prediction matches at least
            # three of the other reference answers.
            acc = min(1, float(len(matchingAns)) / 3)
            ret['match'].append(acc)
    elif method == 'anls':
        ret['gt'] = answers
        ret['pred'] = line['prediction']
        ret['match'] = [anls_compute(x, ret['pred']) for x in ret['gt']]
    elif method == 'relaxed_accuracy':
        ret['gt'] = answers
        ret['pred'] = line['prediction'].strip()
        ret['match'] = [relaxed_correctness(ret['pred'], x) for x in ret['gt']]
    elif method == 'accuracy':
        ret['gt'] = answers
        ret['pred'] = line['prediction'].strip()
        ret['match'] = [(1.0 if (x.strip().lower() == ret['pred'].strip().lower()) else 0.0) for x in ret['gt']]
    else:  # default: use vqa_score to calculate the score
        ret['gt'] = [process_answer(x) for x in answers]
        ret['pred'] = process_answer(line['prediction'])
        ret['match'] = [x == ret['pred'] for x in ret['gt']]

    return ret

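# Illustration: with reference answers ['yes', 'yes', 'no'] and prediction 'yes',
# method='vqa_score' gives match == [1/3, 1/3, 2/3], since each reference is
# scored as min(1, matches among the other references / 3).

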
def VQAEval(eval_file, dataset_name, **kwargs):
    logger = get_logger('Evaluation')
    data = load(eval_file)
    assert 'answer' in data and 'prediction' in data
    data['prediction'] = [str(x) for x in data['prediction']]
    data['answer'] = [str(x) for x in data['answer']]
    lt = len(data)
    pool = mp.Pool(16)
    lines = [data.iloc[i] for i in range(lt)]
    if listinstr(['TextVQA'], dataset_name):
        res = pool.map(partial(process_line, method='vqa_score'), lines)
    elif listinstr(['ChartQA'], dataset_name):
        res = pool.map(partial(process_line, method='relaxed_accuracy'), lines)
    elif listinstr(['OCRVQA'], dataset_name):
        res = pool.map(partial(process_line, method='accuracy'), lines)
    elif listinstr(['DocVQA', 'InfoVQA'], dataset_name):
        res = pool.map(partial(process_line, method='anls'), lines)
    else:  # default: use vqa_score to calculate the score
        res = pool.map(process_line, lines)
    hit = hit_calculate(res, dataset_name)
    ret = dict()
    if 'split' in data:
        splits = set(data['split'])
        for sp in splits:
            sub = [r for l, r in zip(lines, res) if l['split'] == sp]
            hit = hit_calculate(sub, dataset_name)
            ret[sp] = np.mean(hit) * 100
        hit = hit_calculate(res, dataset_name)
        ret['Overall'] = np.mean(hit) * 100
    else:
        ret['Overall'] = np.mean(hit) * 100
        if 'category' in data:
            cates = list(set(data['category']))
            cates.sort()
            for c in cates:
                sub = [r for l, r in zip(lines, res) if l['category'] == c]
                hit = hit_calculate(sub, dataset_name)
                ret[c] = np.mean(hit) * 100
    ret = d2df(ret)
    ret = ret.round(2)

    suffix = eval_file.split('.')[-1]
    result_file = eval_file.replace(f'.{suffix}', '_acc.csv')
    logger.info(f'VQA Eval Finished. Saved to {result_file}.')
    logger.info(ret)
    dump(ret, result_file)

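# Minimal usage sketch (not part of the original module). It assumes the eval
# file is a table with 'answer' and 'prediction' columns readable by vlmeval's
# load(), and that dataset_name contains one of the keywords matched above
# (e.g. 'TextVQA_VAL', 'DocVQA_VAL', 'ChartQA_TEST').
if __name__ == '__main__':
    import sys
    if len(sys.argv) == 3:
        VQAEval(eval_file=sys.argv[1], dataset_name=sys.argv[2])
    else:
        print(f'usage: python {sys.argv[0]} <eval_file> <dataset_name>')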