Files
MiniCPM-o/eval_mm/vlmevalkit/vlmeval/dataset/utils/tablevqabench.py
2025-01-21 15:34:54 +08:00

501 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Copied from https://github.com/allenai/allennlp-semparse
Modified from https://github.com/naver-ai/tablevqabench
"""
import re
import unicodedata
import time
from abc import ABCMeta, abstractmethod
from math import isinf, isnan
# Vision Prompts
# Few-shot prompt template for the VWTQ split. It contains a single
# `{question}` placeholder (presumably filled with str.format by the caller --
# confirm at the call site). The prompt asks the model to join multiple
# answers with '||'; evaluate_wtq in this module rewrites '||' to '|' before
# splitting the prediction into items.
VWTQ_PROMPT = (
    'You are asked to answer questions asked on an image.\n'
    'You should answer the question with a single word.\n'
    'Example: \n'
    'Question: what was the only year mr. wu competed in the olympic games?\n'
    'Answer: 2004\n'
    'Question: which township in pope county, arkansas has the least amount of water area?\n'
    'Answer: Freeman\n'
    'If you have multiple answers, please separate them with || marks. Example: Apple||Banana||Tomato\n\n'
    'Question: {question}\n'
    'Answer:'
)
# Few-shot prompt template for the VTabFact split. It contains a single
# `{question}` placeholder that receives the statement to verify, and asks
# the model to reply with only "True" or "False" (evaluate_tabfact in this
# module looks for those two words, case-insensitively, in the prediction).
VTABFACT_PROMPT = (
    'You are asked to answer whether the statement is True or False based on given image\n'
    'You should only answer True or False.\n'
    'Example: \n'
    'Statement: the milwaukee buck win 6 game in the 2010 - 11 season\n'
    'Answer: True\n'
    'Statement: only the top team score above the average of 8.8\n'
    'Answer: False\n\n'
    'Statement: {question}\n'
    'Answer:'
)
# Few-shot prompt template for the FinTabNetQA split. It contains a single
# `{question}` placeholder and asks for a short answer that includes units
# ($, %, million, ...) when they can be determined.
# NOTE(review): unlike the other two templates there is no blank line between
# the last example and the final question; the text is left untouched here
# because it is sent to the model verbatim.
FINTABNETQA_PROMPT = (
    'You are asked to answer questions asked on a image.\n'
    'You should answer the question within a single word or few words.\n'
    'If units can be known, the answer should include units such as $, %, million and etc.\n'
    'Example: \n'
    'Question: What were the total financing originations for the fiscal year ended October 31, 2004?\n'
    'Answer: $3,852 million\n'
    'Question: What is the time period represented in the table?\n'
    'Answer: October 31\n'
    'Question: What was the percentage of net sales for selling, general and administrative expenses in 2006?\n'
    'Answer: 34.2%\n'
    'Question: {question}\n'
    'Answer:'
)
def evaluate_tabfact(data, score_keys):
    """Score True/False predictions against '1'/'0' ground-truth labels.

    Each instance is updated in place: a ``None`` prediction is replaced by
    the string ``'none'`` and a ``'scores'`` dict keyed by ``score_keys[0]``
    is attached. The per-instance score is ``1`` (correct), ``0`` (wrong) or
    ``None`` when the prediction mentions both "true" and "false" and so
    cannot be parsed automatically.

    Args:
        data (list[dict]): instances with 'prediction' and 'answer' entries;
            'answer' is compared against the strings '1' and '0'.
        score_keys (list[str]): only the first entry is used, as the name of
            the reported score.
    Returns:
        dict: summary with the evaluator name, score name, elapsed wall-clock
        time, number of samples and the accuracy in percent.
    """
    started_at = time.time()
    total = 0
    hits = 0
    ambiguous = 0
    for instance in data:
        # Missing predictions are rewritten in place so they can be lowercased.
        if instance['prediction'] is None:
            instance['prediction'] = 'none'
        answer_text = instance['prediction'].lower()
        label = instance['answer']
        total += 1
        says_true = 'true' in answer_text
        says_false = 'false' in answer_text
        if says_true and says_false:
            # Both words present: cannot decide, flag for manual inspection.
            ambiguous += 1
            score = None
        elif (says_true and label == '1') or (says_false and label == '0'):
            hits += 1
            score = 1
        else:
            score = 0
        instance['scores'] = {score_keys[0]: score}
    if ambiguous > 0:
        print(f'the number of not properly parsed samples: {ambiguous}')
    elapsed_time = time.time() - started_at
    # The 1e-9 terms avoid a division by zero when `data` is empty.
    accuracy = round((hits + 1e-9) / (total + 1e-9), 8) * 100
    return {
        'evaluators': 'correctness',
        'score_info': [score_keys[0]],
        'evaluated_time': elapsed_time,
        'total_num_sample': len(data),
        'average_scores': [accuracy],
    }
def evaluate_wtq(data, score_keys):
    """Score WTQ-style predictions by denotation match.

    The prediction's '||' separators are rewritten to '|', then both the
    prediction and the ground truth are split on '|', converted to typed
    values and compared with ``check_denotation``. Each instance receives a
    ``'scores'`` dict keyed by ``score_keys[0]`` holding ``1`` or ``0``.

    Args:
        data (list[dict]): instances with 'prediction' and 'answer' entries.
            A ``None`` prediction is scored as the string ``'none'``, matching
            what ``evaluate_tabfact`` does, instead of raising.
        score_keys (list[str]): only the first entry is used, as the name of
            the reported score.
    Returns:
        dict: summary with the evaluator name, score name, elapsed wall-clock
        time, number of samples and the accuracy in percent.
    """
    num_examples = 0
    num_correct = 0
    start_time = time.time()
    for instance in data:
        prediction = instance['prediction']
        if prediction is None:
            # A missing model output used to crash on `.replace`; treat it the
            # same way evaluate_tabfact does (the instance is not modified).
            prediction = 'none'
        pred = prediction.replace('||', '|')
        gt = instance['answer']
        original_strings = tsv_unescape_list(gt)
        target_values = to_value_list(original_strings)
        predicted_strings = tsv_unescape_list(pred)
        predicted_values = to_value_list(predicted_strings)
        correct = check_denotation(target_values, predicted_values)
        num_examples += 1
        score = 0
        if correct:
            num_correct += 1
            score = 1
        instance['scores'] = {score_keys[0]: score}
    end_time = time.time()
    elapsed_time = end_time - start_time
    # The 1e-9 terms avoid a division by zero when `data` is empty.
    Accuracy = round((num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
    meta = {
        'evaluators': 'correctness',
        'score_info': [score_keys[0]],
        'evaluated_time': elapsed_time,
        'total_num_sample': len(data),
        'average_scores': [Accuracy],
    }
    return meta
def evaluate_fintabnet(data, score_keys):
    """Score FinTabNetQA predictions with an exact and a relaxed comparison.

    Prediction and answer are passed through ``fintabnet_normalize``. The
    exact score compares the primary normalized values; the relaxed score is
    1 when any of the prediction's normalized variants equals any of the
    answer's variants (for example with the unit multiplier dropped).

    Note the key layout, which is kept as-is for compatibility:
    per instance, ``scores[score_keys[0]]`` holds the *relaxed* score and
    ``scores['exact_score']`` the exact one, while the returned summary lists
    ``['relieved_accuracy', score_keys[0]]`` paired with
    ``[relaxed accuracy, exact accuracy]``.

    Args:
        data (list[dict]): instances with 'prediction' and 'answer' entries.
            A ``None`` prediction is scored as the string ``'none'``, matching
            what ``evaluate_tabfact`` does, instead of raising.
        score_keys (list[str]): only the first entry is used.
    Returns:
        dict: summary with the evaluator name, score names, elapsed wall-clock
        time, number of samples and both accuracies in percent.
    """
    num_examples = 0
    num_correct, _num_correct = 0, 0
    start_time = time.time()
    for instance in data:
        prediction = instance['prediction']
        if prediction is None:
            # A missing model output used to crash inside normalize(); treat
            # it like evaluate_tabfact does (the instance is not modified).
            prediction = 'none'
        pred, preds = fintabnet_normalize(prediction)
        gt, gts = fintabnet_normalize(instance['answer'])
        correct = 1 if gt == pred else 0
        _correct = any(_pred == _gt for _pred in preds for _gt in gts)
        num_examples += 1
        score, _score = 0, 0
        if correct:
            num_correct += 1
            score = 1
        if _correct:
            _num_correct += 1
            _score = 1
        instance['scores'] = {score_keys[0]: _score, 'exact_score': score}
    end_time = time.time()
    elapsed_time = end_time - start_time
    # The 1e-9 terms avoid a division by zero when `data` is empty.
    Accuracy = round((num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
    _Accuracy = round((_num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
    meta = {
        'evaluators': 'correctness',
        'score_info': ['relieved_accuracy', score_keys[0]],
        'evaluated_time': elapsed_time,
        'total_num_sample': len(data),
        'average_scores': [_Accuracy, Accuracy],
    }
    return meta
def fintabnet_normalize(s):
    """Normalize a FinTabNetQA answer into comparable numeric/string forms.

    The text is normalized with ``normalize``, stripped of ``$ ( ) ,`` and of
    a fixed list of unit words, then magnitude words and ``%`` are rewritten
    as exponents (e.g. "3852 million" -> "3852e6", "34.2%" -> "34.2e-2").
    A second variant with those magnitude tokens simply removed is kept too.

    Args:
        s (str or bytes): raw answer or prediction text.
    Returns:
        tuple: ``(value, [value, unit_free_value])``. Both are floats when
        both variants parse as floats; otherwise both are the processed
        strings.
    """
    text = normalize(s)
    unit_words = [
        'dollar', 'gallons', 'square feet', 'shares', 'mbtu',
        'mbpd', 'mbbls', 'mmbtu', 'unit', 'gwh', 'year', 'mmcf', 'mile', 'mboe'
    ]
    # Drop currency signs, parentheses and thousands separators.
    text = re.sub(r'[\$\(\),]', '', text)
    # Drop whole-word unit names, with an optional plural "s".
    unit_word_pattern = r'\b(' + '|'.join(unit_words) + r')s?\b'
    text = re.sub(unit_word_pattern, '', text, flags=re.IGNORECASE)
    # Magnitude words / percent sign -> exponent suffix. The space-prefixed
    # patterns run first so "5 million" becomes "5e6" rather than "5 e6".
    magnitude_rules = {
        r' \bthousand\b': 'e3',
        r' \bmillion\b': 'e6',
        r' \bbillion\b': 'e9',
        r'\bthousand\b': 'e3',
        r'\bmillion\b': 'e6',
        r'\bbillion\b': 'e9',
        r' ?%': 'e-2',
    }
    scaled = text
    unit_free = text
    for regex, exponent in magnitude_rules.items():
        scaled = re.sub(regex, exponent, scaled)
        unit_free = re.sub(regex, '', unit_free)
    try:
        scaled_value = float(scaled)
        return scaled_value, [scaled_value, float(unit_free)]
    except ValueError:
        # Not numeric: fall back to comparing the processed strings.
        return scaled, [scaled, unit_free]
def normalize(x):
    """Return a canonical lower-case form of an answer string.

    Steps, in order: decode non-``str`` input as UTF-8, strip combining
    marks after NFKD decomposition, unify quote and dash characters,
    repeatedly peel trailing citation markers / parenthesised details /
    enclosing double quotes, drop one final period, collapse whitespace and
    lower-case.

    Args:
        x (str or bytes): text to normalize.
    Returns:
        str: the normalized text.
    """
    if not isinstance(x, str):
        x = x.decode('utf8', errors='ignore')
    # Remove diacritics (combining marks left over after decomposition).
    decomposed = unicodedata.normalize('NFKD', x)
    x = ''.join(ch for ch in decomposed if unicodedata.category(ch) != 'Mn')
    # Normalize quotes and dashes.
    x = re.sub(r'[´`]', "'", x)
    x = re.sub(r'[“”]', '"', x)
    x = re.sub(r'[‐‑‒–—−]', '-', x)
    # Keep peeling trailing decorations until a full pass changes nothing.
    previous = None
    while x != previous:
        previous = x
        # Trailing citations such as "[1]" or footnote symbols.
        x = re.sub(r'((?<!^)\[[^\]]*\]|\[\d+\]|[•♦†‡*#+])*$', '', x.strip())
        # Trailing " (details)" groups.
        x = re.sub(r'(?<!^)( \([^)]*\))*$', '', x.strip())
        # A pair of double quotes wrapping the whole string.
        x = re.sub(r'^"([^"]*)"$', r'\1', x.strip())
    # Remove a single final '.'.
    if x.endswith('.'):
        x = x[:-1]
    # Collapse whitespace runs and convert to lower case.
    collapsed = re.sub(r'\s+', ' ', x, flags=re.U)
    return collapsed.lower().strip()
# Value Types
class Value(metaclass=ABCMeta):
    """Abstract base class for typed answer values.

    Subclasses store their canonical string in ``_normalized`` and must
    implement ``match``. The metaclass is declared with Python 3 syntax; the
    former ``__metaclass__`` class attribute is ignored by Python 3, which
    meant ``@abstractmethod`` was never enforced.
    """

    # Should be populated with the normalized string
    _normalized = None

    @abstractmethod
    def match(self, other):
        """Return True if the value matches the other value.

        Args:
            other (Value)
        Returns:
            a boolean
        """
        pass

    @property
    def normalized(self):
        """The normalized string form of this value (``None`` until set)."""
        return self._normalized
class StringValue(Value):
    """A free-text value, compared through its normalized string."""

    def __init__(self, content):
        """Store the normalized form of ``content`` (must be a ``str``)."""
        assert isinstance(content, str)
        self._normalized = normalize(content)
        # Cached because values are deduplicated through sets.
        self._hash = hash(self._normalized)

    def __eq__(self, other):
        if not isinstance(other, StringValue):
            return False
        return self.normalized == other.normalized

    def __hash__(self):
        return self._hash

    def __str__(self):
        return f'S{[self.normalized]}'

    def __repr__(self):
        return self.__str__()

    def match(self, other):
        """Return True when ``other`` has the same normalized string."""
        assert isinstance(other, Value)
        return self.normalized == other.normalized
class NumberValue(Value):
    """A numeric value, compared by amount with a 1e-6 tolerance."""

    def __init__(self, amount, original_string=None):
        """Create a number value.

        Args:
            amount (int or float): the numeric amount; values within 1e-6 of
                an integer are stored as ``int``.
            original_string (str): optional source text; when given, its
                normalized form is used as the normalized string, otherwise
                ``str(amount)`` is used.
        """
        assert isinstance(amount, (int, float))
        if abs(amount - round(amount)) < 1e-6:
            self._amount = int(amount)
        else:
            self._amount = float(amount)
        if not original_string:
            self._normalized = str(self._amount)
        else:
            self._normalized = normalize(original_string)
        self._hash = hash(self._amount)

    @property
    def amount(self):
        return self._amount

    def __eq__(self, other):
        return isinstance(other, NumberValue) and self.amount == other.amount

    def __hash__(self):
        return self._hash

    def __str__(self):
        return 'N({})'.format(self.amount) + str([self.normalized])

    def __repr__(self):
        return self.__str__()

    def match(self, other):
        """Return True on equal normalized strings or near-equal amounts."""
        assert isinstance(other, Value)
        if self.normalized == other.normalized:
            return True
        if isinstance(other, NumberValue):
            return abs(self.amount - other.amount) < 1e-6
        return False

    @staticmethod
    def parse(text):
        """Try to parse into a number.

        Text that ``float`` accepts but that is not a finite number
        ("nan", "inf", "infinity", ...) is rejected with ``None``. This used
        to be an ``assert`` inside an ``except ValueError`` block, so such
        text raised an uncaught AssertionError (or leaked NaN/inf under
        ``python -O``).

        Return:
            the number (int or float) if successful; otherwise None.
        """
        try:
            return int(text)
        except ValueError:
            pass
        try:
            amount = float(text)
        except ValueError:
            return None
        if isnan(amount) or isinf(amount):
            return None
        return amount
class DateValue(Value):
    """A (possibly partial) calendar date; unknown fields are stored as -1."""

    def __init__(self, year, month, day, original_string=None):
        """Create a new DateValue. Placeholders are marked as -1.

        Args:
            year (int): year, or -1 when unknown.
            month (int): 1-12, or -1 when unknown.
            day (int): 1-31, or -1 when unknown.
            original_string (str): optional source text; when given, its
                normalized form is used as the normalized string, otherwise
                a "year-month-day" string with 'xx' placeholders is built.
        """
        assert isinstance(year, int)
        assert isinstance(month, int) and (month == -1 or 1 <= month <= 12)
        assert isinstance(day, int) and (day == -1 or 1 <= day <= 31)
        assert not (year == month == day == -1)
        self._year = year
        self._month = month
        self._day = day
        if not original_string:
            self._normalized = '{}-{}-{}'.format(
                year if year != -1 else 'xx',
                month if month != -1 else 'xx',
                # Compare against the int -1; the previous comparison with
                # the string '-1' was never equal, so an unknown day was
                # rendered as "-1" instead of "xx".
                day if day != -1 else 'xx',
            )
        else:
            self._normalized = normalize(original_string)
        self._hash = hash((self._year, self._month, self._day))

    @property
    def ymd(self):
        return (self._year, self._month, self._day)

    def __eq__(self, other):
        return isinstance(other, DateValue) and self.ymd == other.ymd

    def __hash__(self):
        return self._hash

    def __str__(self):
        return ('D(%d,%d,%d)' % (self._year, self._month, self._day)) + str(
            [self._normalized]
        )

    __repr__ = __str__

    def match(self, other):
        """Return True on equal normalized strings or equal (y, m, d)."""
        assert isinstance(other, Value)
        if self.normalized == other.normalized:
            return True
        if isinstance(other, DateValue):
            return self.ymd == other.ymd
        return False

    @staticmethod
    def parse(text):
        """Try to parse into a date.

        Accepts "year-month-day" where any field may be the placeholder
        'xx' (or 'xxxx' for the year), but not all three at once.

        Return:
            tuple (year, month, date) if successful; otherwise None.
        """
        try:
            ymd = text.lower().split('-')
            if len(ymd) != 3:
                return None
            year = -1 if ymd[0] in ('xx', 'xxxx') else int(ymd[0])
            month = -1 if ymd[1] == 'xx' else int(ymd[1])
            day = -1 if ymd[2] == 'xx' else int(ymd[2])
        except (AttributeError, TypeError, ValueError):
            # Non-text input or non-numeric fields: not a date. Only these
            # errors are swallowed (the old bare `except:` also hid things
            # like KeyboardInterrupt).
            return None
        if year == month == day == -1:
            return None
        if month != -1 and not 1 <= month <= 12:
            return None
        if day != -1 and not 1 <= day <= 31:
            return None
        return (year, month, day)
# Value Instantiation
def to_value(original_string, corenlp_value=None):
    """Wrap a string in the most specific Value type that fits it.

    The text used for type detection is ``corenlp_value`` when it is truthy,
    otherwise ``original_string``. Numbers are tried first, then dates; a
    date with only a year is treated as a number. Anything else becomes a
    StringValue.

    Args:
        original_string (basestring): Original string
        corenlp_value (basestring): Optional value returned from CoreNLP
    Returns:
        Value
    """
    # Already a Value: hand it back untouched.
    if isinstance(original_string, Value):
        return original_string
    probe = corenlp_value or original_string
    # Number?
    number = NumberValue.parse(probe)
    if number is not None:
        return NumberValue(number, original_string)
    # Date?
    parsed_date = DateValue.parse(probe)
    if parsed_date is not None:
        year, month, day = parsed_date
        if month == day == -1:
            # Year-only dates compare as plain numbers.
            return NumberValue(year, original_string)
        return DateValue(year, month, day, original_string)
    # String.
    return StringValue(original_string)
def to_value_list(original_strings, corenlp_values=None):
    """Convert a collection of strings to a de-duplicated list of Values.

    Duplicates are removed by passing the converted values through a set, so
    the order of the result is not guaranteed.

    Args:
        original_strings (list[basestring])
        corenlp_values (list[basestring or None])
    Returns:
        list[Value]
    """
    assert isinstance(original_strings, (list, tuple, set))
    if corenlp_values is None:
        return list({to_value(text) for text in original_strings})
    assert isinstance(corenlp_values, (list, tuple, set))
    assert len(original_strings) == len(corenlp_values)
    paired = zip(original_strings, corenlp_values)
    return list({to_value(text, hint) for (text, hint) in paired})
# Check the Predicted Denotations
def check_denotation(target_values, predicted_values):
    """Return True if the predicted denotation is correct.

    The prediction is correct when both lists have the same length and every
    target value is matched (via ``target.match``) by at least one predicted
    value.

    Args:
        target_values (list[Value])
        predicted_values (list[Value])
    Returns:
        bool
    """
    # Different sizes can never match.
    if len(target_values) != len(predicted_values):
        return False
    # Every target needs at least one matching prediction.
    return all(
        any(target.match(candidate) for candidate in predicted_values)
        for target in target_values
    )
# Batch Mode
def tsv_unescape(x):
    """Unescape strings in the TSV file.

    Escaped characters include:
        newline (0x0A) -> backslash + n
        vertical bar (0x7C) -> backslash + p
        backslash (0x5C) -> backslash + backslash

    The three replacements are applied one after another in that order, not
    in a single pass.

    Args:
        x (str or unicode)
    Returns:
        a unicode
    """
    with_newlines = x.replace(r'\n', '\n')
    with_bars = with_newlines.replace(r'\p', '|')
    return with_bars.replace('\\\\', '\\')
def tsv_unescape_list(x):
    """Unescape a list in the TSV file.

    List items are joined with vertical bars (0x7C); the string is split on
    that character and each item is unescaped with ``tsv_unescape``.

    Args:
        x (str or unicode)
    Returns:
        a list of unicodes
    """
    items = x.split('|')
    return list(map(tsv_unescape, items))