mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-04 17:59:18 +08:00
501 lines
15 KiB
Python
501 lines
15 KiB
Python
"""
|
||
Copied from https://github.com/allenai/allennlp-semparse
|
||
Modified from https://github.com/naver-ai/tablevqabench
|
||
"""
|
||
|
||
import re
|
||
import unicodedata
|
||
import time
|
||
|
||
from abc import ABCMeta, abstractmethod
|
||
from math import isinf, isnan
|
||
|
||
|
||
# Vision Prompts
|
||
# Few-shot prompt for the visual WikiTableQuestions task. The template has a
# single ``{question}`` placeholder; multi-item answers are requested in the
# ``a||b||c`` form that evaluate_wtq later splits on.
VWTQ_PROMPT = '\n'.join((
    'You are asked to answer questions asked on an image.',
    'You should answer the question with a single word.',
    'Example: ',
    'Question: what was the only year mr. wu competed in the olympic games?',
    'Answer: 2004',
    'Question: which township in pope county, arkansas has the least amount of water area?',
    'Answer: Freeman',
    'If you have multiple answers, please separate them with || marks. Example: Apple||Banana||Tomato',
    '',
    'Question: {question}',
    'Answer:',
))
|
||
|
||
# Few-shot prompt for the visual TabFact task. The statement to verify is
# substituted for ``{question}``; the model is asked to reply True or False.
VTABFACT_PROMPT = '\n'.join((
    'You are asked to answer whether the statement is True or False based on given image',
    'You should only answer True or False.',
    'Example: ',
    'Statement: the milwaukee buck win 6 game in the 2010 - 11 season',
    'Answer: True',
    'Statement: only the top team score above the average of 8.8',
    'Answer: False',
    '',
    'Statement: {question}',
    'Answer:',
))
|
||
|
||
# Few-shot prompt for the FinTabNet QA task. Answers are expected to carry
# their units ($, %, million, ...); ``{question}`` is the only placeholder.
FINTABNETQA_PROMPT = '\n'.join((
    'You are asked to answer questions asked on a image.',
    'You should answer the question within a single word or few words.',
    'If units can be known, the answer should include units such as $, %, million and etc.',
    'Example: ',
    'Question: What were the total financing originations for the fiscal year ended October 31, 2004?',
    'Answer: $3,852 million',
    'Question: What is the time period represented in the table?',
    'Answer: October 31',
    'Question: What was the percentage of net sales for selling, general and administrative expenses in 2006?',
    'Answer: 34.2%',
    'Question: {question}',
    'Answer:',
))
|
||
|
||
|
||
def evaluate_tabfact(data, score_keys):
    """Score True/False predictions and summarise the accuracy.

    Every instance is updated in place: a ``None`` prediction is replaced by
    the string ``'none'`` and a ``scores`` dict ``{score_keys[0]: score}`` is
    attached. The score is 1 when the lower-cased prediction contains
    ``'true'`` and the answer is ``'1'``, or contains ``'false'`` and the
    answer is ``'0'``; it is ``None`` when the prediction contains both words
    (counted as "not properly parsed"), and 0 otherwise.

    Args:
        data: iterable of dicts with ``'prediction'`` and ``'answer'`` keys.
        score_keys: sequence whose first item names the per-instance score.
    Returns:
        a dict with the evaluator name, score label, elapsed seconds, sample
        count and the accuracy in percent.
    """
    total = 0
    hits = 0
    ambiguous = 0
    started = time.time()
    for sample in data:
        if sample['prediction'] is None:
            sample['prediction'] = 'none'
        reply = sample['prediction'].lower()
        label = sample['answer']
        total += 1

        says_true = 'true' in reply
        says_false = 'false' in reply
        if says_true and says_false:
            # Both words present: the reply cannot be decided automatically.
            ambiguous += 1
            verdict = None
        elif (says_true and label == '1') or (says_false and label == '0'):
            hits += 1
            verdict = 1
        else:
            verdict = 0
        sample['scores'] = {score_keys[0]: verdict}

    if ambiguous > 0:
        print(f'the number of not properly parsed samples: {ambiguous}')
    elapsed = time.time() - started

    # The 1e-9 terms keep the division defined when there are no samples.
    accuracy = round((hits + 1e-9) / (total + 1e-9), 8) * 100
    return {
        'evaluators': 'correctness',
        'score_info': [score_keys[0]],
        'evaluated_time': elapsed,
        'total_num_sample': len(data),
        'average_scores': [accuracy],
    }
|
||
|
||
|
||
def evaluate_wtq(data, score_keys):
    """Score WikiTableQuestions-style predictions by denotation match.

    The prediction's ``||`` separators are rewritten to ``|`` so that both the
    prediction and the answer can be split with ``tsv_unescape_list``. Both
    lists are converted with ``to_value_list`` and compared with
    ``check_denotation``. Each instance receives
    ``scores = {score_keys[0]: 0 or 1}`` in place.

    Args:
        data: iterable of dicts with ``'prediction'`` and ``'answer'`` keys.
        score_keys: sequence whose first item names the per-instance score.
    Returns:
        a dict with the evaluator name, score label, elapsed seconds, sample
        count and the accuracy in percent.
    """
    num_examples = 0
    num_correct = 0
    start_time = time.time()

    for instance in data:
        prediction = instance['prediction']
        if prediction is None:
            # A missing model output is scored like evaluate_tabfact does:
            # as the literal text 'none', instead of crashing on .replace().
            prediction = 'none'
        pred = prediction.replace('||', '|')
        gt = instance['answer']

        target_values = to_value_list(tsv_unescape_list(gt))
        predicted_values = to_value_list(tsv_unescape_list(pred))
        correct = check_denotation(target_values, predicted_values)

        num_examples += 1
        score = 0
        if correct:
            num_correct += 1
            score = 1
        instance['scores'] = {score_keys[0]: score}

    elapsed_time = time.time() - start_time

    # The 1e-9 terms keep the division defined when there are no samples.
    accuracy = round((num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
    meta = {
        'evaluators': 'correctness',
        'score_info': [score_keys[0]],
        'evaluated_time': elapsed_time,
        'total_num_sample': len(data),
        'average_scores': [accuracy],
    }
    return meta
|
||
|
||
|
||
def evaluate_fintabnet(data, score_keys):
    """Score FinTabNet QA predictions with exact and relaxed matching.

    Prediction and answer are both passed through ``fintabnet_normalize``,
    which returns a primary value plus a list of candidate values. The exact
    score is 1 when the two primary values are equal; the relaxed score is 1
    when any prediction candidate equals any answer candidate.

    Each instance receives, in place,
    ``scores = {score_keys[0]: relaxed_score, 'exact_score': exact_score}``.

    NOTE(review): in the returned meta, ``score_keys[0]`` labels the *exact*
    accuracy while the relaxed one is labelled ``'relieved_accuracy'``, which
    is the opposite of the per-instance keys. Kept as is because callers may
    rely on it; confirm the intended labelling.

    Args:
        data: iterable of dicts with ``'prediction'`` and ``'answer'`` keys.
        score_keys: sequence whose first item names a score (see above).
    Returns:
        a dict with the evaluator name, score labels, elapsed seconds, sample
        count and ``[relaxed_accuracy, exact_accuracy]`` in percent.
    """
    num_examples = 0
    num_correct, _num_correct = 0, 0
    start_time = time.time()
    for instance in data:
        prediction = instance['prediction']
        if prediction is None:
            # A missing model output is scored as the literal text 'none'
            # (as evaluate_tabfact does) instead of crashing in normalize().
            prediction = 'none'
        pred, preds = fintabnet_normalize(prediction)
        gt, gts = fintabnet_normalize(instance['answer'])

        correct = 1 if gt == pred else 0
        _correct = any(_pred == _gt for _pred in preds for _gt in gts)

        num_examples += 1
        score, _score = 0, 0
        if correct:
            num_correct += 1
            score = 1
        if _correct:
            _num_correct += 1
            _score = 1
        instance['scores'] = {score_keys[0]: _score, 'exact_score': score}

    elapsed_time = time.time() - start_time

    # The 1e-9 terms keep the division defined when there are no samples.
    accuracy = round((num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
    _accuracy = round((_num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
    meta = {
        'evaluators': 'correctness',
        'score_info': ['relieved_accuracy', score_keys[0]],
        'evaluated_time': elapsed_time,
        'total_num_sample': len(data),
        'average_scores': [_accuracy, accuracy],
    }
    return meta
|
||
|
||
|
||
def fintabnet_normalize(s):
    """Normalize a FinTabNet answer and try to read it as a number.

    Steps: run ``normalize``; drop ``$``, parentheses and commas; drop a fixed
    list of unit words (optionally pluralised); then turn scale words and
    ``%`` into exponent suffixes (``million`` -> ``e6``, ``%`` -> ``e-2``).
    A second, "unit free" variant has the scale words and ``%`` removed
    instead of converted.

    Args:
        s: the raw answer or prediction.
    Returns:
        ``(value, [value, unit_free_value])``. All three are floats when both
        variants parse with ``float``; otherwise they are the transformed
        strings.
    """
    text = normalize(s)

    # Currency signs, parentheses and thousands separators carry no value.
    text = re.sub(r'[\$\(\),]', '', text)

    # Unit words are removed only as whole words, with an optional plural "s".
    unit_words = [
        'dollar', 'gallons', 'square feet', 'shares', 'mbtu',
        'mbpd', 'mbbls', 'mmbtu', 'unit', 'gwh', 'year', 'mmcf', 'mile', 'mboe',
    ]
    unit_word_pattern = r'\b(' + '|'.join(unit_words) + r')s?\b'
    text = re.sub(unit_word_pattern, '', text, flags=re.IGNORECASE)

    # Applied in this order: the space-prefixed forms go first so that
    # "3 million" becomes "3e6" rather than "3 e6".
    scale_rules = (
        (r' \bthousand\b', 'e3'),
        (r' \bmillion\b', 'e6'),
        (r' \bbillion\b', 'e9'),
        (r'\bthousand\b', 'e3'),
        (r'\bmillion\b', 'e6'),
        (r'\bbillion\b', 'e9'),
        (r' ?%', 'e-2'),
    )
    scaled = text
    unit_free = text
    for rule, exponent in scale_rules:
        scaled = re.sub(rule, exponent, scaled)
        unit_free = re.sub(rule, '', unit_free)

    try:
        return float(scaled), [float(scaled), float(unit_free)]
    except ValueError:
        # Not numeric: fall back to the transformed strings.
        return scaled, [scaled, unit_free]
|
||
|
||
|
||
def normalize(x):
    """Return a canonical lower-case form of an answer string.

    Non-``str`` input is decoded as UTF-8 (undecodable bytes are ignored).
    The text then has diacritics removed, typographic quotes and dashes
    replaced by ASCII ones, trailing citation marks and parenthesised details
    stripped, one pair of enclosing double quotes removed, a final ``.``
    dropped, and whitespace collapsed.

    Args:
        x: ``str``, or an object with a ``decode`` method such as ``bytes``.
    Returns:
        the normalized ``str``.
    """
    if not isinstance(x, str):
        x = x.decode('utf8', errors='ignore')

    # Decompose, then drop the combining marks (category 'Mn').
    decomposed = unicodedata.normalize('NFKD', x)
    x = ''.join(ch for ch in decomposed if unicodedata.category(ch) != 'Mn')

    # Typographic quotes and dashes -> ASCII equivalents.
    for fancy, plain in (
        (r'[‘’´`]', "'"),
        (r'[“”]', '"'),
        (r'[‐‑‒–—−]', '-'),
    ):
        x = re.sub(fancy, plain, x)

    # Repeat the three clean-ups until the text stops changing, because
    # removing one trailing decoration can expose another.
    previous = None
    while x != previous:
        previous = x
        # Trailing citations such as "[1]" or footnote symbols.
        x = re.sub(r'((?<!^)\[[^\]]*\]|\[\d+\]|[•♦†‡*#+])*$', '', x.strip())
        # Trailing " (details)" groups, unless they are the whole string.
        x = re.sub(r'(?<!^)( \([^)]*\))*$', '', x.strip())
        # A single pair of enclosing double quotes.
        x = re.sub(r'^"([^"]*)"$', r'\1', x.strip())

    if x.endswith('.'):
        x = x[:-1]

    collapsed = re.sub(r'\s+', ' ', x, flags=re.U)
    return collapsed.lower().strip()
|
||
|
||
|
||
# Value Types
|
||
class Value(metaclass=ABCMeta):
    """Abstract base class for a comparable answer value.

    Subclasses must implement ``match`` and are expected to set
    ``_normalized`` to the normalized string form of the value.

    The metaclass is passed with Python 3 syntax: a ``__metaclass__`` class
    attribute is ignored on Python 3, which left ``@abstractmethod``
    unenforced and allowed ``Value()`` to be instantiated directly.
    """

    # Should be populated with the normalized string
    _normalized = None

    @abstractmethod
    def match(self, other):
        """Return True if the value matches the other value.

        Args:
            other (Value)
        Returns:
            a boolean
        """

    @property
    def normalized(self):
        """The normalized string form, or None if a subclass never set it."""
        return self._normalized
|
||
|
||
|
||
class StringValue(Value):
    """A plain-text value, compared through its normalized form."""

    def __init__(self, content):
        """Store the normalized form of ``content`` (must be a ``str``)."""
        assert isinstance(content, str)
        self._normalized = normalize(content)
        # Hash is cached; it is derived from the same text __eq__ compares.
        self._hash = hash(self._normalized)

    def __eq__(self, other):
        if not isinstance(other, StringValue):
            return False
        return self.normalized == other.normalized

    def __hash__(self):
        return self._hash

    def __str__(self):
        return f'S{[self.normalized]}'

    def __repr__(self):
        return self.__str__()

    def match(self, other):
        """Return True when ``other`` (any Value) has the same normalized text."""
        assert isinstance(other, Value)
        return other.normalized == self.normalized
|
||
|
||
|
||
class NumberValue(Value):
    """A numeric value, optionally remembering the text it came from."""

    def __init__(self, amount, original_string=None):
        """Create a NumberValue.

        Args:
            amount (int or float): the number; values within 1e-6 of an
                integer are stored as ``int``.
            original_string (str): optional source text. When given, its
                normalized form is used as ``normalized``; otherwise
                ``str(amount)`` is used.
        """
        assert isinstance(amount, (int, float))
        if abs(amount - round(amount)) < 1e-6:
            self._amount = int(amount)
        else:
            self._amount = float(amount)
        if not original_string:
            self._normalized = str(self._amount)
        else:
            self._normalized = normalize(original_string)
        self._hash = hash(self._amount)

    @property
    def amount(self):
        return self._amount

    def __eq__(self, other):
        return isinstance(other, NumberValue) and self.amount == other.amount

    def __hash__(self):
        return self._hash

    def __str__(self):
        return 'N({})'.format(self.amount) + str([self.normalized])

    def __repr__(self):
        return self.__str__()

    def match(self, other):
        """Return True if ``other`` has the same normalized text, or is a
        NumberValue whose amount differs by less than 1e-6."""
        assert isinstance(other, Value)
        if self.normalized == other.normalized:
            return True
        if isinstance(other, NumberValue):
            return abs(self.amount - other.amount) < 1e-6
        return False

    @staticmethod
    def parse(text):
        """Try to parse into a number.

        Return:
            the number (int or float) if successful; otherwise None.
            Text that parses to NaN or infinity (e.g. ``'nan'``, ``'inf'``)
            is treated as not a number and yields None.
        """
        try:
            return int(text)
        except (ValueError, TypeError):
            pass
        try:
            amount = float(text)
        except (ValueError, TypeError):
            return None
        # Checked explicitly rather than with ``assert``: an AssertionError
        # would escape to the caller, and asserts are stripped under -O.
        if isnan(amount) or isinf(amount):
            return None
        return amount
|
||
|
||
|
||
class DateValue(Value):
    """A (year, month, day) value in which unknown parts are stored as -1."""

    def __init__(self, year, month, day, original_string=None):
        """Create a new DateValue. Placeholders are marked as -1.

        Args:
            year (int): the year, or -1.
            month (int): 1..12, or -1.
            day (int): 1..31, or -1.
            original_string (str): optional source text. When given, its
                normalized form is used as ``normalized``; otherwise a
                ``year-month-day`` string is built with ``xx`` standing in
                for each placeholder.
        """
        assert isinstance(year, int)
        assert isinstance(month, int) and (month == -1 or 1 <= month <= 12)
        assert isinstance(day, int) and (day == -1 or 1 <= day <= 31)
        assert not (year == month == day == -1)
        self._year = year
        self._month = month
        self._day = day
        if not original_string:
            self._normalized = '{}-{}-{}'.format(
                year if year != -1 else 'xx',
                month if month != -1 else 'xx',
                # Compare with the int -1; comparing with the string '-1'
                # never matched, so a missing day was rendered as "-1".
                day if day != -1 else 'xx',
            )
        else:
            self._normalized = normalize(original_string)
        self._hash = hash((self._year, self._month, self._day))

    @property
    def ymd(self):
        return (self._year, self._month, self._day)

    def __eq__(self, other):
        return isinstance(other, DateValue) and self.ymd == other.ymd

    def __hash__(self):
        return self._hash

    def __str__(self):
        return ('D(%d,%d,%d)' % (self._year, self._month, self._day)) + str(
            [self._normalized]
        )

    __repr__ = __str__

    def match(self, other):
        """Return True if ``other`` has the same normalized text, or is a
        DateValue with the same (year, month, day)."""
        assert isinstance(other, Value)
        if self.normalized == other.normalized:
            return True
        if isinstance(other, DateValue):
            return self.ymd == other.ymd
        return False

    @staticmethod
    def parse(text):
        """Try to parse into a date.

        The accepted form is ``year-month-day`` where each part is an integer
        or the placeholder ``xx`` (``xxxx`` is also accepted for the year).

        Return:
            tuple (year, month, date) if successful; otherwise None.
        """
        try:
            parts = text.lower().split('-')
        except (AttributeError, TypeError):
            # Not a str (e.g. None or bytes).
            return None
        if len(parts) != 3:
            return None
        try:
            year = -1 if parts[0] in ('xx', 'xxxx') else int(parts[0])
            month = -1 if parts[1] == 'xx' else int(parts[1])
            day = -1 if parts[2] == 'xx' else int(parts[2])
        except ValueError:
            return None
        # Explicit checks instead of asserts, so that validation still
        # happens when Python runs with -O.
        if year == month == day == -1:
            return None
        if month != -1 and not 1 <= month <= 12:
            return None
        if day != -1 and not 1 <= day <= 31:
            return None
        return (year, month, day)
|
||
|
||
|
||
# Value Instantiation
|
||
def to_value(original_string, corenlp_value=None):
    """Build the most specific Value for a string.

    The text used for type detection is ``corenlp_value`` when it is truthy,
    otherwise ``original_string``. Numbers are tried first, then dates; a
    date with only a year becomes a NumberValue. Anything else becomes a
    StringValue. ``original_string`` is always what the resulting Value
    normalizes.

    Args:
        original_string (str or Value): Original string; a Value is returned
            unchanged.
        corenlp_value (str): Optional value returned from CoreNLP
    Returns:
        Value
    """
    if isinstance(original_string, Value):
        return original_string

    probe = corenlp_value if corenlp_value else original_string

    number = NumberValue.parse(probe)
    if number is not None:
        return NumberValue(number, original_string)

    date_parts = DateValue.parse(probe)
    if date_parts is None:
        return StringValue(original_string)

    year, month, day = date_parts
    if month == day == -1:
        # Year-only dates are compared as plain numbers.
        return NumberValue(year, original_string)
    return DateValue(year, month, day, original_string)
|
||
|
||
|
||
def to_value_list(original_strings, corenlp_values=None):
    """Convert strings to a de-duplicated list of Values.

    Items are passed through ``to_value`` and collected in a set, so equal
    Values collapse into one and the order of the result is not defined.

    Args:
        original_strings (list, tuple or set of str)
        corenlp_values (list, tuple or set of str or None): optional, must
            have the same length as ``original_strings``
    Returns:
        list[Value]
    """
    assert isinstance(original_strings, (list, tuple, set))
    if corenlp_values is None:
        return list({to_value(text) for text in original_strings})

    assert isinstance(corenlp_values, (list, tuple, set))
    assert len(original_strings) == len(corenlp_values)
    unique_values = {
        to_value(text, hint)
        for text, hint in zip(original_strings, corenlp_values)
    }
    return list(unique_values)
|
||
|
||
|
||
# Check the Predicted Denotations
|
||
def check_denotation(target_values, predicted_values):
    """Decide whether a predicted denotation equals the target one.

    The two lists must have the same length, and every target value must
    ``match`` at least one predicted value.

    Args:
        target_values (list[Value])
        predicted_values (list[Value])
    Returns:
        bool
    """
    if len(target_values) != len(predicted_values):
        return False
    return all(
        any(expected.match(candidate) for candidate in predicted_values)
        for expected in target_values
    )
|
||
|
||
|
||
# Batch Mode
|
||
def tsv_unescape(x):
    """Undo the escaping used for strings in the TSV file.

    The replacements are applied one after another, in this order:
        backslash + n         -> newline (0x0A)
        backslash + p         -> vertical bar (0x7C)
        backslash + backslash -> backslash (0x5C)

    Args:
        x (str)
    Returns:
        the unescaped str
    """
    substitutions = (
        (r'\n', '\n'),
        (r'\p', '|'),
        ('\\\\', '\\'),
    )
    for escaped, literal in substitutions:
        x = x.replace(escaped, literal)
    return x
|
||
|
||
|
||
def tsv_unescape_list(x):
    """Split a TSV list field and unescape each item.

    List items are joined with vertical bars (0x7C); each piece is passed
    through ``tsv_unescape``.

    Args:
        x (str)
    Returns:
        a list of str
    """
    return list(map(tsv_unescape, x.split('|')))
|