mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-05 02:09:20 +08:00
447 lines
16 KiB
Python
447 lines
16 KiB
Python
import re
|
||
import json
|
||
import sympy as sp
|
||
import numpy as np
|
||
from sympy import simplify, Eq, sympify, Pow, pi
|
||
from sympy.parsing.latex import parse_latex
|
||
import sys
|
||
import math
|
||
import os
|
||
import argparse
|
||
|
||
from .image_base import ImageBaseDataset
|
||
from ..utils import track_progress_rich
|
||
from ..smp import load, dump
|
||
|
||
|
||
class AutoScoringJudge:
    """Rule-based judge deciding whether a predicted math answer matches a reference.

    Both answers are first normalised (``\\boxed{...}`` extraction, unit and
    symbol clean-up), then compared item by item using, in order: plain string
    equality, interval comparison, numeric comparison within ``self.precision``,
    symbolic expression comparison and finally equation comparison.
    """

    def __init__(self):
        # Map of special symbols to their replacements. Insertion order matters:
        # "\\, " must be replaced before "\\,".
        self.special_signal_map = {
            "\\left": "",
            "\\right": "",
            "厘米": "",
            # "∶": ":",
            ",": ",",
            "$": "",
            "(": "(",
            ")": ")",
            "\\infty": "oo",
            "\\colon ": ":",
            # "\\approx": "=",
            # "\\simeq": "=",
            # "\\sim": "=",
            # "^\\prime": "'",
            # "^{\\prime}": "'",
            "+": "+",
            "\\, ": "",
            "\\,": "",
            "^\\circ": "",
            "^{\\circ}": "",
            # "%": "",
        }
        # Whatever object parse_latex produces for "\pi"; substituted numerically later.
        self.pi = parse_latex("\\pi")
        # MM-Math default precision
        self.precision = 1e-2

    def trans_greater_sign_to_interval(self, expr: str):
        """Rewrite a chained inequality such as ``a<x<b`` as the string ``(a, b)``."""
        parts = expr.split("<")
        return "(" + parts[0] + ", " + parts[-1] + ")"

    def split_by_comma(self, expr: str):
        """Split ``expr`` on commas that are not nested inside ``()`` or ``[]``.

        Returns a list of whitespace-stripped pieces; an empty input gives ``[]``.
        """
        depth = 0
        pieces = []
        start_idx = 0
        for i, char in enumerate(expr):
            if char in ("(", "["):
                depth += 1
            elif char in (")", "]"):
                depth -= 1
            elif char == "," and depth == 0:
                pieces.append(expr[start_idx:i].strip())
                start_idx = i + 1

        if start_idx < len(expr):
            pieces.append(expr[start_idx:].strip())

        return pieces

    def trans_plus_minus_sign(self, expr_list: list):
        """Expand every item containing ``\\pm`` into a ``+`` variant and a ``-`` variant."""
        expanded = []
        for expr in expr_list:
            if "\\pm" in expr:
                expanded.append(expr.replace("\\pm", "+"))
                expanded.append(expr.replace("\\pm", "-"))
            else:
                expanded.append(expr)
        return expanded

    def judge(self, expression1, expression2, precision=1e-2):
        """Return True when ``expression2`` (prediction) matches ``expression1`` (ground truth).

        Args:
            expression1: Ground-truth answer text.
            expression2: Predicted answer text.
            precision: Allowed absolute error, either one number applied to every
                answer item or a list with one tolerance per ground-truth item.

        Returns:
            bool: True if every ground-truth item can be paired with a distinct
            matching prediction item. Preprocessing failures count as a mismatch.
        """
        # Work on a copy: entries are deleted below and the caller's list must survive.
        precision = list(precision) if isinstance(precision, list) else [precision]

        try:
            expression1, expression2 = self.preprocess(expression1, expression2)
        except Exception:
            return False
        if expression1 == expression2:
            return True

        # Drop Chinese characters unless the whole answer is Chinese (e.g. a
        # yes/no style answer), in which case it is compared verbatim.
        if not re.fullmatch(r"[\u4e00-\u9fff]+", expression1):
            expression1 = re.sub(r'[\u4e00-\u9fff]+', '', expression1)
        if not re.fullmatch(r'[\u4e00-\u9fff]+', expression2):
            expression2 = re.sub(r'[\u4e00-\u9fff]+', '', expression2)

        # A chained inequality (exactly two "<") is compared as an interval.
        if self.is_two_greater_sign(expression1):
            expression1 = self.trans_greater_sign_to_interval(expression1)
        if self.is_two_greater_sign(expression2):
            expression2 = self.trans_greater_sign_to_interval(expression2)

        temp_list1 = self.trans_plus_minus_sign(self.split_by_comma(expression1))
        temp_list2 = self.trans_plus_minus_sign(self.split_by_comma(expression2))

        # A single tolerance is broadcast to every ground-truth item.
        if len(precision) <= 1:
            precision = precision * len(temp_list1)

        if len(temp_list1) != len(temp_list2):
            return False

        # Greedily pair each ground-truth item with some prediction item.
        idx = -1
        while temp_list1:
            idx = (idx + 1) % len(temp_list1)
            item1 = temp_list1[idx]
            self.precision = precision[idx]

            for item2 in temp_list2:
                if self.is_equal(item1, item2):
                    # Delete by index so items and tolerances stay aligned even
                    # when the answer list contains duplicate values.
                    del temp_list1[idx]
                    del precision[idx]
                    temp_list2.remove(item2)
                    break
            else:
                # No prediction item matches this ground-truth item.
                return False

        return True

    def is_interval(self, expr):
        """Return True if ``expr`` starts with ``(``/``[`` and ends with ``)``/``]``."""
        return expr.startswith(("(", "[")) and expr.endswith((")", "]"))

    def is_two_greater_sign(self, expr):
        """Return True if ``expr`` contains exactly two ``<`` characters."""
        return len(re.findall(r'<', expr)) == 2

    def sympy_sub_pi(self, expression_sympy):
        """Substitute the parsed pi object with the float ``math.pi``."""
        return expression_sympy.subs(self.pi, math.pi)

    def is_equal(self, expression1, expression2):
        """Compare two single answer items; ``expression1`` is the ground truth.

        Strategies are tried in order and the first success wins. A strategy
        that raises is treated as "not applicable" and the next one is tried,
        except for the interval strategy, whose failure ends the comparison.
        """
        if expression1 == expression2 and expression1 != "":
            return True

        # First check if both are intervals
        if self.is_interval(expression1) and self.is_interval(expression2):
            try:
                if self.interval_equal(expression1, expression2):
                    return True
            except Exception:
                return False

        # Then check for numerical equality
        try:
            if self.numerical_equal(expression1, expression2):
                return True
        except Exception:
            pass  # not plain numbers; fall through to symbolic checks

        # Then check if expressions are mathematically equal
        try:
            if self.expression_equal(expression1, expression2) and \
                    not ("=" in expression1 and "=" in expression2):
                return True
        except Exception:
            pass  # unparsable as expressions; fall through

        # Lastly, check for equation equality
        try:
            if self.equation_equal(expression1, expression2):
                return True
        except Exception:
            pass  # unparsable as equations

        return False

    def numerical_equal(self, expression1: str, expression2: str, include_percentage: bool = True):
        """Compare two numeric strings within ``self.precision`` (with 1% slack).

        With ``include_percentage`` the reference is also accepted when scaled
        by 1/100 or 100. Raises ``ValueError`` if either string is not a float.
        """
        reference = float(expression1)
        prediction = float(expression2)

        if include_percentage:
            candidates = [reference / 100, reference, reference * 100]
        else:
            candidates = [reference]

        tolerance = self.precision * 1.01
        return any(abs(item - prediction) <= tolerance for item in candidates)

    def expression_equal(self, exp1, exp2):
        """Check whether two LaTeX expressions are mathematically equivalent.

        If an expression contains ``=``, only the part after the first ``=`` is
        compared. Purely numeric expressions are compared within
        ``self.precision``; symbolic ones via ``simplify`` of their difference.
        Parsing errors propagate to the caller.
        """
        def extract_expression(expression):
            if "=" in expression:
                expression = expression.split("=")[1]
            return expression.strip()

        exp1 = extract_expression(exp1)
        exp2 = extract_expression(exp2)

        exp_too_long = len(exp1) > 300 or len(exp2) > 300

        expr1_sym = sympify(parse_latex(exp1))
        expr2_sym = sympify(parse_latex(exp2))
        if expr1_sym == expr2_sym:
            return True

        expr1_sym = self.sympy_sub_pi(expr1_sym)
        expr2_sym = self.sympy_sub_pi(expr2_sym)

        has_symbol_1 = expr1_sym.has(sp.Symbol)
        has_symbol_2 = expr2_sym.has(sp.Symbol)

        # One side symbolic and the other numeric can never be equal.
        if has_symbol_1 != has_symbol_2:
            return False

        if not has_symbol_1:
            try:
                if not (self.can_compute_power(expr1_sym) and self.can_compute_power(expr2_sym)):
                    print("These two numbers cannot be calculated by the current computer for: "
                          f"\"{str(expr1_sym)}\" and \"{str(expr2_sym)}\"")
                    return False
                if exp_too_long:
                    print(f'Expression {exp1} or {exp2} is too long to compute. ')
                    return False
                return bool(abs(expr1_sym.evalf() - expr2_sym.evalf()) <= self.precision * 1.01)
            except Exception:
                return False

        if exp_too_long:
            print(f'Expression {exp1} or {exp2} is too long to compute. ')
            return False

        try:
            num_value = simplify(expr1_sym - expr2_sym).evalf()
            # bool() inside the try: an undecidable symbolic relation raises
            # and is reported as "not equal" instead of leaking a sympy object.
            return bool(abs(num_value) < 1e-3)
        except Exception:
            return False

    def equation_equal(self, expression1, expression2):
        """Check whether two LaTeX equations are equivalent up to a nonzero integer factor.

        Each equation is reduced to ``lhs - rhs``; the two are considered equal
        when either quotient simplifies to a nonzero integer. Raises if an
        input does not contain exactly one ``=`` or cannot be parsed.
        """
        def simplify_equation(latex_eq):
            lhs, rhs = latex_eq.split('=')
            equation = Eq(parse_latex(lhs), parse_latex(rhs))
            return simplify(equation.lhs - equation.rhs)

        expr1_sym = simplify_equation(expression1)
        expr2_sym = simplify_equation(expression2)

        division_result_1 = simplify(expr1_sym / expr2_sym)
        division_result_2 = simplify(expr2_sym / expr1_sym)

        return bool(
            (division_result_1.is_Integer and division_result_1 != 0)
            or (division_result_2.is_Integer and division_result_2 != 0)
        )

    def interval_equal(self, expression1, expression2):
        """Check whether two intervals (or ``\\cup`` unions of intervals) are equal.

        Unions are compared piece by piece in order; each piece must use the
        same bracket types and have pairwise-equivalent endpoints.
        """
        def compare_two_interval(inter1, inter2):
            # Whitespace left around "\cup" would otherwise hide the brackets.
            inter1, inter2 = inter1.strip(), inter2.strip()
            if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]:
                return False

            items_1 = inter1.strip('[]()').split(',')
            items_2 = inter2.strip('[]()').split(',')

            # zip() would silently ignore surplus endpoints on either side.
            if len(items_1) != len(items_2):
                return False

            for item_1, item_2 in zip(items_1, items_2):
                if not self.expression_equal(item_1, item_2):
                    return False
            return True

        if expression1 == expression2:
            return True

        inter_list1 = expression1.split("\\cup")
        inter_list2 = expression2.split("\\cup")

        if len(inter_list1) != len(inter_list2):
            return False

        for inter1, inter2 in zip(inter_list1, inter_list2):
            if not compare_two_interval(inter1, inter2):
                return False
        return True

    def preprocess(self, expression1, expression2):
        """Normalise both answers before comparison.

        Extracts ``\\boxed{...}`` contents (falling back to ``$...$`` spans on
        the last line, then to the raw text), removes unit text and applies
        ``self.special_signal_map``.

        Returns:
            tuple[str, str]: the cleaned ``(expression1, expression2)``.

        Raises:
            ValueError: if a ``\\boxed{`` has no matching closing brace.
        """
        # Unit fragments removed verbatim, in this order.
        unit_tokens = (
            "\\text{cm}^2", "\\text{cm}", "\\,cm", "\\text{ cm}", "cm",
            "\\text{分米}^2", "cm^{2}", "60 \\text{ cm}^2", "\\ \\text{m}", "\\text{米}",
        )

        def extract_boxed_content(latex_str):
            contents = []
            for match in re.finditer(r'\\boxed{', latex_str):
                start_index = match.end()
                end_index = start_index
                depth = 1

                # Walk forward until the brace opened by \boxed{ is closed.
                while depth > 0 and end_index < len(latex_str):
                    if latex_str[end_index] == '{':
                        depth += 1
                    elif latex_str[end_index] == '}':
                        depth -= 1
                    end_index += 1

                if depth != 0:
                    raise ValueError("Mismatched braces in LaTeX string.")
                contents.append(latex_str[start_index:end_index - 1])

            if not contents:
                last_line_ans = latex_str.strip().split("\n")[-1]
                contents = re.findall(r"\$(.*?)\$", last_line_ans)
                if not contents:
                    return latex_str

            # Every item keeps a trailing comma; the later strip() removes the last one.
            return "".join(item + "," for item in contents)

        def special_symbol_replace(expression):
            for unit in unit_tokens:
                expression = expression.replace(unit, "")
            expression = expression.strip()

            # Drop a trailing "m" left at the very end of the answer.
            expression = re.sub(r"(.+)m$", r"\1", expression)

            if "\\in " in expression:
                expression = expression.split("\\in ")[1]

            for signal, replacement in self.special_signal_map.items():
                expression = expression.replace(signal, replacement)

            # Trig of a bare integer is read as degrees: \sin30 -> \sin((30/180)\pi)
            expression = re.sub(r'(\\sin|\\cos|\\tan)(\d+)', r'\1((\2/180)\\pi)', expression)

            expression = expression.strip("\n,.:;^_=+`!@#%^&*~,。")

            # Unwrap \mathrm{...} / \mathbf{...}
            pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}'
            expression = re.sub(pattern, r'\1', expression)

            return expression

        exp1 = special_symbol_replace(extract_boxed_content(expression1))
        exp2 = special_symbol_replace(extract_boxed_content(expression2))

        return exp1, exp2

    def can_compute_power(self, expr):
        """Return False for a top-level power that is non-numeric or has a huge exponent.

        Anything that is not a ``Pow`` is considered computable.
        """
        if not isinstance(expr, Pow):
            return True  # Not a power expression, can compute

        base, exp = expr.as_base_exp()
        if not (base.is_number and exp.is_number):
            return False

        MAX_EXP = 1000  # Adjust based on computing environment
        if abs(exp.evalf()) > MAX_EXP:
            return False
        return True
|
||
|
||
|
||
class MMMath(ImageBaseDataset):
    """MM-Math dataset whose predictions are scored with ``AutoScoringJudge``."""

    TYPE = 'VQA'

    DATASET_URL = {
        'MM-Math': 'https://opencompass.openxlab.space/utils/VLMEval/MM-Math.tsv',
    }
    DATASET_MD5 = {
        'MM-Math': '1f064ed7c4e0e8926a3fa65849419ca5',
    }

    @classmethod
    def evaluate(cls, eval_file, **kwargs):
        """Score the predictions stored in ``eval_file`` and write a score summary.

        The loaded table must provide the columns ``answer``, ``prediction``,
        ``difficulty``, ``year``, ``knowledge_l1`` and ``knowledge_l2``. A
        ``hit`` column is added and the table is written back to ``eval_file``.

        Args:
            eval_file: Path of the prediction file; presumably a spreadsheet
                loaded as a pandas DataFrame by ``load`` (verify against caller).
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            dict: overall accuracy plus accuracy per difficulty, year and
            knowledge category. The same dict is dumped next to ``eval_file``
            as ``<name>_score.json``.
        """
        data = load(eval_file)
        judger = AutoScoringJudge()
        func = judger.judge

        tups = [dict(expression1=x, expression2=y) for x, y in zip(data['answer'], data['prediction'])]

        res = track_progress_rich(func, tups, nproc=16)
        data['hit'] = res
        dump(data, eval_file)

        # Derive the score path from the file stem so that a non-.xlsx input
        # can never map back onto eval_file itself.
        score_file = os.path.splitext(eval_file)[0] + '_score.json'

        score = {'overall': np.mean(data['hit'])}

        # (column in the table, prefix used in the score keys)
        breakdowns = (
            ('difficulty', 'Difficulty'),
            ('year', 'Year'),
            ('knowledge_l1', 'Knowledge-L1'),
            ('knowledge_l2', 'Knowledge-L2'),
        )
        for column, label in breakdowns:
            for value in set(data[column]):
                score[f'{label}-{value}'] = np.mean(data[data[column] == value]['hit'])

        dump(score, score_file)
        return score
|