Modify eval_mm for MiniCPM-o 2.6

This commit is contained in:
Poppy Xu
2025-01-21 15:34:54 +08:00
parent ec68cefc17
commit d8f382e157
82 changed files with 14279 additions and 843 deletions

View File

@@ -0,0 +1,500 @@
"""
Copied from https://github.com/allenai/allennlp-semparse
Modified from https://github.com/naver-ai/tablevqabench
"""
import re
import unicodedata
import time
from abc import ABCMeta, abstractmethod
from math import isinf, isnan
# Vision Prompts
VWTQ_PROMPT = (
'You are asked to answer questions asked on an image.\n'
'You should answer the question with a single word.\n'
'Example: \n'
'Question: what was the only year mr. wu competed in the olympic games?\n'
'Answer: 2004\n'
'Question: which township in pope county, arkansas has the least amount of water area?\n'
'Answer: Freeman\n'
'If you have multiple answers, please separate them with || marks. Example: Apple||Banana||Tomato\n\n'
'Question: {question}\n'
'Answer:'
)
VTABFACT_PROMPT = (
'You are asked to answer whether the statement is True or False based on given image\n'
'You should only answer True or False.\n'
'Example: \n'
'Statement: the milwaukee buck win 6 game in the 2010 - 11 season\n'
'Answer: True\n'
'Statement: only the top team score above the average of 8.8\n'
'Answer: False\n\n'
'Statement: {question}\n'
'Answer:'
)
FINTABNETQA_PROMPT = (
'You are asked to answer questions asked on a image.\n'
'You should answer the question within a single word or few words.\n'
'If units can be known, the answer should include units such as $, %, million and etc.\n'
'Example: \n'
'Question: What were the total financing originations for the fiscal year ended October 31, 2004?\n'
'Answer: $3,852 million\n'
'Question: What is the time period represented in the table?\n'
'Answer: October 31\n'
'Question: What was the percentage of net sales for selling, general and administrative expenses in 2006?\n'
'Answer: 34.2%\n'
'Question: {question}\n'
'Answer:'
)
def evaluate_tabfact(data, score_keys):
num_examples = 0
num_correct = 0
manual_check = 0
start_time = time.time()
for instance in data:
if instance['prediction'] is None:
instance['prediction'] = 'none'
pred = instance['prediction'].lower()
gt = instance['answer']
num_examples += 1
if 'true' in pred and 'false' in pred:
manual_check += 1
score = None
elif 'true' in pred and gt == '1':
num_correct += 1
score = 1
elif 'false' in pred and gt == '0':
num_correct += 1
score = 1
else:
score = 0
instance['scores'] = {score_keys[0]: score}
if manual_check > 0:
print(f'the number of not properly parsed samples: {manual_check}')
end_time = time.time()
elapsed_time = end_time - start_time
Accuracy = round((num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
meta = {
'evaluators': 'correctness',
'score_info': [score_keys[0]],
'evaluated_time': elapsed_time,
'total_num_sample': len(data),
'average_scores': [Accuracy],
}
return meta
def evaluate_wtq(data, score_keys):
num_examples = 0
num_correct = 0
start_time = time.time()
for instance in data:
pred = instance['prediction'].replace('||', '|')
gt = instance['answer']
original_strings = tsv_unescape_list(gt)
target_values = to_value_list(original_strings)
predicted_strings = tsv_unescape_list(pred)
predicted_values = to_value_list(predicted_strings)
correct = check_denotation(target_values, predicted_values)
num_examples += 1
score = 0
if correct:
num_correct += 1
score = 1
instance['scores'] = {score_keys[0]: score}
end_time = time.time()
elapsed_time = end_time - start_time
Accuracy = round((num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
meta = {
'evaluators': 'correctness',
'score_info': [score_keys[0]],
'evaluated_time': elapsed_time,
'total_num_sample': len(data),
'average_scores': [Accuracy],
}
return meta
def evaluate_fintabnet(data, score_keys):
num_examples = 0
num_correct, _num_correct = 0, 0
start_time = time.time()
for instance in data:
pred, preds = fintabnet_normalize(instance['prediction'])
gt, gts = fintabnet_normalize(instance['answer'])
correct = 1 if gt == pred else 0
_correct = any(_pred == _gt for _pred in preds for _gt in gts)
num_examples += 1
score, _score = 0, 0
if correct:
num_correct += 1
score = 1
if _correct:
_num_correct += 1
_score = 1
instance['scores'] = {score_keys[0]: _score, 'exact_score': score}
end_time = time.time()
elapsed_time = end_time - start_time
Accuracy = round((num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
_Accuracy = round((_num_correct + 1e-9) / (num_examples + 1e-9), 8) * 100
meta = {
'evaluators': 'correctness',
'score_info': ['relieved_accuracy', score_keys[0]],
'evaluated_time': elapsed_time,
'total_num_sample': len(data),
'average_scores': [_Accuracy, Accuracy],
}
return meta
def fintabnet_normalize(s):
s = normalize(s)
remove_words = [
'dollar', 'gallons', 'square feet', 'shares', 'mbtu',
'mbpd', 'mbbls', 'mmbtu', 'unit', 'gwh', 'year', 'mmcf', 'mile', 'mboe'
]
# Data specific filtering using regular expressions
# Remove special characters like $, (, and )
s = re.sub(r'[\$\(\),]', '', s)
# Replace "dollar" with empty string if it's not part of another word
pattern = r'\b(' + '|'.join(remove_words) + r')s?\b'
s = re.sub(pattern, '', s, flags=re.IGNORECASE)
# Unit conversion dictionary with regex patterns for flexibility
unit_conversion = {
r' \bthousand\b': 'e3',
r' \bmillion\b': 'e6',
r' \bbillion\b': 'e9',
r'\bthousand\b': 'e3',
r'\bmillion\b': 'e6',
r'\bbillion\b': 'e9',
r' ?%': 'e-2',
}
# Convert percentages to their decimal representation.
# Applying this after unit_conversion prevents "percent" from being processed
# in cases like "million %", which would be incorrect.
# s = re.sub(r' ?%', 'e-2', s)
# s_percent = re.sub(r' ?%', '', s_percent)
s_unit_free = s
# Iterate over unit_conversion and apply transformations
for pattern, value in unit_conversion.items():
s = re.sub(pattern, value, s)
s_unit_free = re.sub(pattern, '', s_unit_free)
# Attempt to convert to float
try:
return float(s), [float(s), float(s_unit_free)]
except ValueError:
# Return the original string and the error for debugging purposes
return s, [s, s_unit_free]
def normalize(x):
if not isinstance(x, str):
x = x.decode('utf8', errors='ignore')
# Remove diacritics
x = ''.join(
c for c in unicodedata.normalize('NFKD', x) if unicodedata.category(c) != 'Mn'
)
# Normalize quotes and dashes
x = re.sub(r'[´`]', "'", x)
x = re.sub(r'[“”]', '"', x)
x = re.sub(r'[‐‑‒–—−]', '-', x)
while True:
old_x = x
# Remove citations
x = re.sub(r'((?<!^)\[[^\]]*\]|\[\d+\]|[•♦†‡*#+])*$', '', x.strip())
# Remove details in parenthesis
x = re.sub(r'(?<!^)( \([^)]*\))*$', '', x.strip())
# Remove outermost quotation mark
x = re.sub(r'^"([^"]*)"$', r'\1', x.strip())
if x == old_x:
break
# Remove final '.'
if x and x[-1] == '.':
x = x[:-1]
# Collapse whitespaces and convert to lower case
x = re.sub(r'\s+', ' ', x, flags=re.U).lower().strip()
return x
# Value Types
class Value(object):
__metaclass__ = ABCMeta
# Should be populated with the normalized string
_normalized = None
@abstractmethod
def match(self, other):
"""Return True if the value matches the other value.
Args:
other (Value)
Returns:
a boolean
"""
pass
@property
def normalized(self):
return self._normalized
class StringValue(Value):
def __init__(self, content):
assert isinstance(content, str)
self._normalized = normalize(content)
self._hash = hash(self._normalized)
def __eq__(self, other):
return isinstance(other, StringValue) and self.normalized == other.normalized
def __hash__(self):
return self._hash
def __str__(self):
return 'S' + str([self.normalized])
def __repr__(self):
return self.__str__()
def match(self, other):
assert isinstance(other, Value)
return self.normalized == other.normalized
class NumberValue(Value):
def __init__(self, amount, original_string=None):
assert isinstance(amount, (int, float))
if abs(amount - round(amount)) < 1e-6:
self._amount = int(amount)
else:
self._amount = float(amount)
if not original_string:
self._normalized = str(self._amount)
else:
self._normalized = normalize(original_string)
self._hash = hash(self._amount)
@property
def amount(self):
return self._amount
def __eq__(self, other):
return isinstance(other, NumberValue) and self.amount == other.amount
def __hash__(self):
return self._hash
def __str__(self):
return 'N({})'.format(self.amount) + str([self.normalized])
def __repr__(self):
return self.__str__()
def match(self, other):
assert isinstance(other, Value)
if self.normalized == other.normalized:
return True
if isinstance(other, NumberValue):
return abs(self.amount - other.amount) < 1e-6
return False
@staticmethod
def parse(text):
"""Try to parse into a number.
Return:
the number (int or float) if successful; otherwise None.
"""
try:
return int(text)
except ValueError:
try:
amount = float(text)
assert not isnan(amount) and not isinf(amount)
return amount
except ValueError:
return None
class DateValue(Value):
def __init__(self, year, month, day, original_string=None):
"""Create a new DateValue. Placeholders are marked as -1."""
assert isinstance(year, int)
assert isinstance(month, int) and (month == -1 or 1 <= month <= 12)
assert isinstance(day, int) and (day == -1 or 1 <= day <= 31)
assert not (year == month == day == -1)
self._year = year
self._month = month
self._day = day
if not original_string:
self._normalized = '{}-{}-{}'.format(
year if year != -1 else 'xx',
month if month != -1 else 'xx',
day if day != '-1' else 'xx',
)
else:
self._normalized = normalize(original_string)
self._hash = hash((self._year, self._month, self._day))
@property
def ymd(self):
return (self._year, self._month, self._day)
def __eq__(self, other):
return isinstance(other, DateValue) and self.ymd == other.ymd
def __hash__(self):
return self._hash
def __str__(self):
return ('D(%d,%d,%d)' % (self._year, self._month, self._day)) + str(
[self._normalized]
)
__repr__ = __str__
def match(self, other):
assert isinstance(other, Value)
if self.normalized == other.normalized:
return True
if isinstance(other, DateValue):
return self.ymd == other.ymd
return False
@staticmethod
def parse(text):
"""Try to parse into a date.
Return:
tuple (year, month, date) if successful; otherwise None.
"""
try:
ymd = text.lower().split('-')
assert len(ymd) == 3
year = -1 if ymd[0] in ('xx', 'xxxx') else int(ymd[0])
month = -1 if ymd[1] == 'xx' else int(ymd[1])
day = -1 if ymd[2] == 'xx' else int(ymd[2])
assert not (year == month == day == -1)
assert month == -1 or 1 <= month <= 12
assert day == -1 or 1 <= day <= 31
return (year, month, day)
except:
return None
# Value Instantiation
def to_value(original_string, corenlp_value=None):
"""Convert the string to Value object.
Args:
original_string (basestring): Original string
corenlp_value (basestring): Optional value returned from CoreNLP
Returns:
Value
"""
if isinstance(original_string, Value):
# Already a Value
return original_string
if not corenlp_value:
corenlp_value = original_string
# Number?
amount = NumberValue.parse(corenlp_value)
if amount is not None:
return NumberValue(amount, original_string)
# Date?
ymd = DateValue.parse(corenlp_value)
if ymd is not None:
if ymd[1] == ymd[2] == -1:
return NumberValue(ymd[0], original_string)
else:
return DateValue(ymd[0], ymd[1], ymd[2], original_string)
# String.
return StringValue(original_string)
def to_value_list(original_strings, corenlp_values=None):
"""Convert a list of strings to a list of Values
Args:
original_strings (list[basestring])
corenlp_values (list[basestring or None])
Returns:
list[Value]
"""
assert isinstance(original_strings, (list, tuple, set))
if corenlp_values is not None:
assert isinstance(corenlp_values, (list, tuple, set))
assert len(original_strings) == len(corenlp_values)
return list(
set(to_value(x, y) for (x, y) in zip(original_strings, corenlp_values))
)
else:
return list(set(to_value(x) for x in original_strings))
# Check the Predicted Denotations
def check_denotation(target_values, predicted_values):
"""Return True if the predicted denotation is correct.
Args:
target_values (list[Value])
predicted_values (list[Value])
Returns:
bool
"""
# Check size
if len(target_values) != len(predicted_values):
return False
# Check items
for target in target_values:
if not any(target.match(pred) for pred in predicted_values):
return False
return True
# Batch Mode
def tsv_unescape(x):
"""Unescape strings in the TSV file.
Escaped characters include:
newline (0x10) -> backslash + n
vertical bar (0x7C) -> backslash + p
backslash (0x5C) -> backslash + backslash
Args:
x (str or unicode)
Returns:
a unicode
"""
return x.replace(r'\n', '\n').replace(r'\p', '|').replace('\\\\', '\\')
def tsv_unescape_list(x):
"""Unescape a list in the TSV file.
List items are joined with vertical bars (0x5C)
Args:
x (str or unicode)
Returns:
a list of unicodes
"""
return [tsv_unescape(y) for y in x.split('|')]