mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-05 18:29:18 +08:00
Modify eval_mm for MiniCPM-o 2.6
This commit is contained in:
446
eval_mm/vlmevalkit/vlmeval/dataset/mmmath.py
Normal file
446
eval_mm/vlmevalkit/vlmeval/dataset/mmmath.py
Normal file
@@ -0,0 +1,446 @@
|
||||
import re
|
||||
import json
|
||||
import sympy as sp
|
||||
import numpy as np
|
||||
from sympy import simplify, Eq, sympify, Pow, pi
|
||||
from sympy.parsing.latex import parse_latex
|
||||
import sys
|
||||
import math
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from .image_base import ImageBaseDataset
|
||||
from ..utils import track_progress_rich
|
||||
from ..smp import load, dump
|
||||
|
||||
|
||||
class AutoScoringJudge:
|
||||
def __init__(self):
|
||||
# Map of special symbols to their replacements
|
||||
self.special_signal_map = {
|
||||
"\\left": "",
|
||||
"\\right": "",
|
||||
"厘米":"",
|
||||
# "∶": ":",
|
||||
",": ",",
|
||||
"$": "",
|
||||
"(":"(",
|
||||
")":")",
|
||||
"\\infty":"oo",
|
||||
"\\colon ":":",
|
||||
# "\\approx": "=",
|
||||
# "\\simeq": "=",
|
||||
# "\\sim": "=",
|
||||
# "^\\prime": "'",
|
||||
# "^{\\prime}": "'",
|
||||
"+":"+",
|
||||
"\\, ": "",
|
||||
"\\,":"",
|
||||
"^\\circ": "",
|
||||
"^{\\circ}": "",
|
||||
# "%": "",
|
||||
}
|
||||
self.pi = parse_latex("\\pi")
|
||||
# MM-Math default precision
|
||||
self.precision = 1e-2
|
||||
|
||||
def trans_greater_sign_to_interval(self, expr:str):
|
||||
expr_tmp = expr.split("<")
|
||||
return "(" + expr_tmp[0] + ", " + expr_tmp[-1] + ")"
|
||||
|
||||
def split_by_comma(self, expr: str):
|
||||
# Splits expressions by commas outside of brackets
|
||||
in_bracket_num = 0
|
||||
splitted_expr = []
|
||||
start_idx = 0
|
||||
for i, char in enumerate(expr):
|
||||
if char in ["(", "["]:
|
||||
in_bracket_num += 1
|
||||
elif char in [")", "]"]:
|
||||
in_bracket_num -= 1
|
||||
elif char == "," and in_bracket_num == 0:
|
||||
splitted_expr.append(expr[start_idx:i].strip())
|
||||
start_idx = i + 1
|
||||
|
||||
if start_idx < len(expr):
|
||||
splitted_expr.append(expr[start_idx:].strip())
|
||||
|
||||
return splitted_expr
|
||||
|
||||
def trans_plus_minus_sign(self, expr_list: list):
|
||||
# Translates plus-minus signs into separate expressions
|
||||
new_expr_list = []
|
||||
for expr in expr_list:
|
||||
if "\\pm" in expr:
|
||||
new_expr_list.append(expr.replace("\\pm", "+"))
|
||||
new_expr_list.append(expr.replace("\\pm", "-"))
|
||||
else:
|
||||
new_expr_list.append(expr)
|
||||
|
||||
return new_expr_list
|
||||
|
||||
def judge(self, expression1, expression2, precision=1e-2):
|
||||
# Judge if two expressions are equal (expression1 is considered as the Ground Truth)
|
||||
# Default precision is a list for supporting multiple expressions
|
||||
precision = precision if isinstance(precision, list) else [precision]
|
||||
|
||||
try:
|
||||
expression1, expression2 = self.preprocess(expression1, expression2)
|
||||
except:
|
||||
return False
|
||||
if expression1 == expression2:
|
||||
# print("Exactly equal")
|
||||
return True
|
||||
|
||||
# Remove Chinese characters from the string, as answers like "yes" or "no" in Chinese have been considered
|
||||
expression1 = expression1 if re.fullmatch(r"[\u4e00-\u9fff]+", expression1) else re.sub(r'[\u4e00-\u9fff]+', '', expression1) # noqa: E501
|
||||
expression2 = expression2 if re.fullmatch(r'[\u4e00-\u9fff]+', expression2) else re.sub(r'[\u4e00-\u9fff]+', '', expression2) # noqa: E501
|
||||
# Check if two < or > in expression
|
||||
if self.is_two_greater_sign(expression1):
|
||||
expression1 = self.trans_greater_sign_to_interval(expression1)
|
||||
|
||||
if self.is_two_greater_sign(expression2):
|
||||
expression2 = self.trans_greater_sign_to_interval(expression2)
|
||||
|
||||
expression1 = self.split_by_comma(expression1)
|
||||
expression2 = self.split_by_comma(expression2)
|
||||
|
||||
temp_list1 = self.trans_plus_minus_sign(expression1)
|
||||
temp_list2 = self.trans_plus_minus_sign(expression2)
|
||||
|
||||
# Set up a list for allowed errors
|
||||
if len(precision) <= 1:
|
||||
precision = precision * len(temp_list1)
|
||||
|
||||
if len(temp_list1) != len(temp_list2):
|
||||
return False
|
||||
|
||||
# Check if elements in both lists can be paired and are equal
|
||||
idx = -1
|
||||
while len(temp_list1) != 0:
|
||||
idx = (idx + 1) % len(temp_list1)
|
||||
|
||||
item1 = temp_list1[idx]
|
||||
self.precision = precision[idx]
|
||||
|
||||
for item2 in temp_list2:
|
||||
if self.is_equal(item1, item2):
|
||||
temp_list1.remove(item1)
|
||||
temp_list2.remove(item2)
|
||||
precision.remove(self.precision)
|
||||
break
|
||||
else:
|
||||
# If no match was found, return False
|
||||
return False
|
||||
|
||||
# If all elements are matched, return True
|
||||
return True
|
||||
|
||||
def is_interval(self, expr):
|
||||
# Checks if an expression is an interval
|
||||
return expr.startswith(("(", "[")) and expr.endswith((")", "]"))
|
||||
|
||||
def is_two_greater_sign(self, expr):
|
||||
match = re.findall(r'<', expr)
|
||||
return len(match) == 2
|
||||
|
||||
def sympy_sub_pi(self, expression_sympy):
|
||||
# Replaces the symbol for pi in sympy expressions with its numerical value
|
||||
return expression_sympy.subs(self.pi, math.pi)
|
||||
|
||||
def is_equal(self, expression1, expression2):
|
||||
# Default first expression is ground truth. Check if expressions are equal in different aspects
|
||||
if expression1 == expression2 and expression1 != "" and expression2 != "":
|
||||
# print("Equivalent natively")
|
||||
return True
|
||||
|
||||
# First check if both are intervals
|
||||
if self.is_interval(expression1) and self.is_interval(expression2):
|
||||
try:
|
||||
if self.interval_equal(expression1, expression2):
|
||||
# print("Interval equivalent")
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
# Then check for numerical equality
|
||||
try:
|
||||
if self.numerical_equal(expression1, expression2):
|
||||
# print("Numerically equivalent")
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
# Then check if expressions are mathematically equal
|
||||
try:
|
||||
if self.expression_equal(expression1, expression2) and not ("=" in expression1 and "=" in expression2):
|
||||
# print("Expression equivalent")
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
# Lastly, check for equation equality
|
||||
try:
|
||||
if self.equation_equal(expression1, expression2):
|
||||
# print("Equation equivalent")
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
def numerical_equal(self, expression1: str, expression2: str, include_percentage: bool = True):
|
||||
# Check if two numerical values are equal within an allowed error range
|
||||
# Includes possible percentage cases
|
||||
reference = float(expression1)
|
||||
prediction = float(expression2)
|
||||
|
||||
if include_percentage:
|
||||
gt_result = [reference / 100, reference, reference * 100]
|
||||
else:
|
||||
gt_result = [reference]
|
||||
|
||||
for item in gt_result:
|
||||
if abs(item - prediction) <= self.precision * 1.01:
|
||||
return True
|
||||
return False
|
||||
|
||||
def expression_equal(self, exp1, exp2):
|
||||
# Check if two expressions are mathematically equivalent
|
||||
# Extract expression and use sympy for equivalence checking
|
||||
def extract_expression(expression):
|
||||
if "=" in expression:
|
||||
expression = expression.split("=")[1]
|
||||
return expression.strip()
|
||||
|
||||
exp1 = extract_expression(exp1)
|
||||
exp2 = extract_expression(exp2)
|
||||
|
||||
exp_too_long = len(exp1) > 300 or len(exp2) > 300
|
||||
|
||||
expr1_sym = sympify(parse_latex(exp1))
|
||||
expr2_sym = sympify(parse_latex(exp2))
|
||||
if expr1_sym == expr2_sym:
|
||||
return True
|
||||
else:
|
||||
expr1_sym = self.sympy_sub_pi(expr1_sym)
|
||||
expr2_sym = self.sympy_sub_pi(expr2_sym)
|
||||
|
||||
if (expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol)) or \
|
||||
(not expr1_sym.has(sp.Symbol) and expr2_sym.has(sp.Symbol)):
|
||||
return False
|
||||
elif not expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol):
|
||||
try:
|
||||
if not (self.can_compute_power(expr1_sym) and self.can_compute_power(expr2_sym)):
|
||||
print("These two numbers cannot be calculated by the current computer for: "
|
||||
f"\"{str(expr1_sym)}\" and \"{str(expr2_sym)}\"")
|
||||
return False
|
||||
if exp_too_long:
|
||||
print(f'Expression {exp1} or {exp2} is too long to compute. ')
|
||||
return False
|
||||
if abs(expr1_sym.evalf() - expr2_sym.evalf()) <= self.precision * 1.01:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except:
|
||||
return False
|
||||
elif exp_too_long:
|
||||
print(f'Expression {exp1} or {exp2} is too long to compute. ')
|
||||
return False
|
||||
else:
|
||||
try:
|
||||
simplified_expr = simplify(expr1_sym - expr2_sym)
|
||||
num_value = simplified_expr.evalf()
|
||||
return abs(num_value) < 1e-3
|
||||
except:
|
||||
return False
|
||||
|
||||
def equation_equal(self, expression1, expression2):
|
||||
# Check if two equations are mathematically equivalent
|
||||
# Simplify equations and use sympy for equivalence checking
|
||||
def simplify_equation(latex_eq):
|
||||
lhs, rhs = latex_eq.split('=')
|
||||
|
||||
lhs_expr = parse_latex(lhs)
|
||||
rhs_expr = parse_latex(rhs)
|
||||
|
||||
equation = Eq(lhs_expr, rhs_expr)
|
||||
|
||||
simplified_eq = simplify(equation.lhs - equation.rhs)
|
||||
|
||||
return simplified_eq
|
||||
|
||||
expr1_sym = simplify_equation(expression1)
|
||||
expr2_sym = simplify_equation(expression2)
|
||||
|
||||
division_result_1 = simplify(expr1_sym / expr2_sym)
|
||||
division_result_2 = simplify(expr2_sym / expr1_sym)
|
||||
|
||||
if ((division_result_1.is_Integer and division_result_1 != 0) or # noqa: W504
|
||||
(division_result_2.is_Integer and division_result_2 != 0)):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def interval_equal(self, expression1, expression2):
|
||||
# Check if two intervals are mathematically equivalent
|
||||
def compare_two_interval(inter1, inter2):
|
||||
if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]:
|
||||
return False
|
||||
|
||||
inter1 = inter1.strip('[]()')
|
||||
inter2 = inter2.strip('[]()')
|
||||
|
||||
items_1 = inter1.split(',')
|
||||
items_2 = inter2.split(',')
|
||||
|
||||
for item_1, item_2 in zip(items_1, items_2):
|
||||
if not self.expression_equal(item_1, item_2):
|
||||
return False
|
||||
return True
|
||||
|
||||
interval1 = expression1
|
||||
interval2 = expression2
|
||||
|
||||
if interval1 == interval2:
|
||||
return True
|
||||
else:
|
||||
inter_list1 = interval1.split("\\cup")
|
||||
inter_list2 = interval2.split("\\cup")
|
||||
|
||||
if len(inter_list1) != len(inter_list2):
|
||||
return False
|
||||
else:
|
||||
for inter1, inter2 in zip(inter_list1, inter_list2):
|
||||
if not compare_two_interval(inter1, inter2):
|
||||
return False
|
||||
return True
|
||||
|
||||
def preprocess(self, expression1, expression2):
|
||||
# Preprocess expressions to extract and replace special symbols
|
||||
def extract_boxed_content(latex_str):
|
||||
boxed_matches = re.finditer(r'\\boxed{', latex_str)
|
||||
results = ""
|
||||
|
||||
for match in boxed_matches:
|
||||
start_index = match.end()
|
||||
end_index = start_index
|
||||
stack = 1
|
||||
|
||||
while stack > 0 and end_index < len(latex_str):
|
||||
if latex_str[end_index] == '{':
|
||||
stack += 1
|
||||
elif latex_str[end_index] == '}':
|
||||
stack -= 1
|
||||
end_index += 1
|
||||
|
||||
if stack == 0:
|
||||
content = latex_str[start_index:end_index - 1]
|
||||
results += content + ","
|
||||
else:
|
||||
raise ValueError("Mismatched braces in LaTeX string.")
|
||||
|
||||
if results == "":
|
||||
last_line_ans = latex_str.strip().split("\n")[-1]
|
||||
dollar_pattern = r"\$(.*?)\$"
|
||||
answers = re.findall(dollar_pattern, last_line_ans)
|
||||
|
||||
if answers:
|
||||
for ans in answers:
|
||||
results += ans + ","
|
||||
else:
|
||||
results = latex_str
|
||||
|
||||
return results
|
||||
|
||||
def sepcial_symbol_replace(expression):
|
||||
|
||||
expression = expression.replace("\\text{cm}^2", '').replace("\\text{cm}", "").replace("\\,cm", '').replace("\\text{ cm}", '').replace("cm", '').replace("\\text{分米}^2", '').replace("cm^{2}", '').replace("60 \\text{ cm}^2",'').replace("\\ \\text{m}", "").replace("\\text{米}","").strip() # noqa: E501
|
||||
|
||||
expression = re.sub(r"(.+)m$", r"\1", expression)
|
||||
|
||||
if "\\in " in expression:
|
||||
expression = expression.split("\\in ")[1]
|
||||
|
||||
for signal in self.special_signal_map:
|
||||
expression = expression.replace(signal, self.special_signal_map[signal])
|
||||
|
||||
expression = re.sub(r'(\\sin|\\cos|\\tan)(\d+)', r'\1((\2/180)\\pi)', expression)
|
||||
|
||||
expression = expression.strip("\n,.:;^_=+`!@#%^&*~,。")
|
||||
|
||||
pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}'
|
||||
expression = re.sub(pattern, r'\1', expression)
|
||||
|
||||
return expression
|
||||
|
||||
exp1, exp2 = extract_boxed_content(expression1), extract_boxed_content(expression2)
|
||||
|
||||
exp1, exp2 = sepcial_symbol_replace(exp1), sepcial_symbol_replace(exp2)
|
||||
|
||||
return exp1, exp2
|
||||
|
||||
def can_compute_power(self, expr):
|
||||
# Checks if a power expression can be computed
|
||||
if isinstance(expr, Pow):
|
||||
base, exp = expr.as_base_exp()
|
||||
if base.is_number and exp.is_number:
|
||||
MAX_EXP = 1000 # Adjust based on computing environment
|
||||
if abs(exp.evalf()) > MAX_EXP:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return True # Not a power expression, can compute
|
||||
|
||||
|
||||
class MMMath(ImageBaseDataset):
    """MM-Math dataset whose predictions are scored with ``AutoScoringJudge``."""

    TYPE = 'VQA'

    DATASET_URL = {
        'MM-Math': 'https://opencompass.openxlab.space/utils/VLMEval/MM-Math.tsv',
    }
    DATASET_MD5 = {
        'MM-Math': '1f064ed7c4e0e8926a3fa65849419ca5',
    }

    @classmethod
    def evaluate(cls, eval_file, **kwargs):
        """Score the predictions stored in ``eval_file`` and write a score summary.

        The table loaded from ``eval_file`` must provide the columns ``answer``,
        ``prediction``, ``difficulty``, ``year``, ``knowledge_l1`` and
        ``knowledge_l2``.  A ``hit`` column holding the per-row verdict is added
        and the table is written back to ``eval_file``.

        Args:
            eval_file: Path of the prediction file.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            dict mapping ``'overall'`` and one key per difficulty, year and
            knowledge category to the mean of ``hit`` over the matching rows.
            The same dict is dumped next to ``eval_file`` as
            ``<name>_score.json``.
        """
        # NOTE(review): `data` is indexed like a pandas DataFrame below —
        # confirm against `load`.
        data = load(eval_file)
        judger = AutoScoringJudge()

        tups = [dict(expression1=x, expression2=y) for x, y in zip(data['answer'], data['prediction'])]

        data['hit'] = track_progress_rich(judger.judge, tups, nproc=16)
        dump(data, eval_file)

        # Derive the score path from the file stem so that a non-.xlsx input
        # can never resolve to (and overwrite) the prediction file itself.
        score_file = os.path.splitext(eval_file)[0] + '_score.json'

        score = {'overall': np.mean(data['hit'])}

        # (column, key prefix) pairs, in the order the breakdowns are reported.
        breakdowns = (
            ('difficulty', 'Difficulty'),
            ('year', 'Year'),
            ('knowledge_l1', 'Knowledge-L1'),
            ('knowledge_l2', 'Knowledge-L2'),
        )
        for column, prefix in breakdowns:
            for value in set(data[column]):
                score[f'{prefix}-{value}'] = np.mean(data[data[column] == value]['hit'])

        dump(score, score_file)
        return score
|
||||
Reference in New Issue
Block a user