mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-05 18:29:18 +08:00
Modify eval_mm for MiniCPM-o 2.6
This commit is contained in:
532
eval_mm/vlmevalkit/vlmeval/dataset/utils/olympiadbench.py
Normal file
532
eval_mm/vlmevalkit/vlmeval/dataset/utils/olympiadbench.py
Normal file
@@ -0,0 +1,532 @@
|
||||
import re
|
||||
import json
|
||||
from math import isclose
|
||||
import sympy as sp
|
||||
from sympy import simplify, Eq, sympify, evalf, Pow
|
||||
from sympy.parsing.latex import parse_latex
|
||||
import antlr4
|
||||
from decimal import Decimal, getcontext
|
||||
from fractions import Fraction
|
||||
import sys
|
||||
import math
|
||||
|
||||
|
||||
chinese_answer_type_dict = {
|
||||
'Numerical': '数值',
|
||||
'Expression': '表达式',
|
||||
'Equation': '方程',
|
||||
'Interval': '区间'
|
||||
}
|
||||
english_answer_type_dict = {
|
||||
'Numerical': 'a numerical value',
|
||||
'Expression': 'an expression',
|
||||
'Equation': 'an equation',
|
||||
'Interval': 'an interval'
|
||||
}
|
||||
|
||||
|
||||
def get_single_answer_type_text(answer_type, is_chinese):
|
||||
if '-' in answer_type: # No need now
|
||||
answer_type = answer_type[:answer_type.find('-')]
|
||||
for t in ['Numerical', 'Expression', 'Equation', 'Interval']:
|
||||
if t in answer_type:
|
||||
if is_chinese:
|
||||
return chinese_answer_type_dict[t]
|
||||
else:
|
||||
return english_answer_type_dict[t]
|
||||
exit(f'Error parsing answer type {answer_type}!')
|
||||
|
||||
|
||||
def get_answer_type_text(answer_type, is_chinese, multiple_answer):
|
||||
# 'Tuple' has various meanings in different context, such as position or values of a series of variable,
|
||||
# so it may lead to confusion to directly use 'tuple' in the prompt.
|
||||
if ('Need_human_evaluate' in answer_type) or ('Tuple' in answer_type):
|
||||
full_answer_text = ''
|
||||
else:
|
||||
if not multiple_answer:
|
||||
answer_text = get_single_answer_type_text(answer_type, is_chinese)
|
||||
if is_chinese:
|
||||
full_answer_text = f',答案类型为{answer_text}'
|
||||
else:
|
||||
full_answer_text = f"The answer of The problem should be {answer_text}. "
|
||||
else:
|
||||
if ',' not in answer_type: # Same answer type for all answers
|
||||
answer_text = get_single_answer_type_text(answer_type, is_chinese)
|
||||
if is_chinese:
|
||||
full_answer_text = f',题目有多个答案,答案类型均为{answer_text}'
|
||||
else:
|
||||
full_answer_text = f'The problem has multiple answers, each of them should be {answer_text}. '
|
||||
else:
|
||||
answer_types = answer_type.split(',')
|
||||
answer_types = [get_single_answer_type_text(t, is_chinese) for t in answer_types]
|
||||
if len(set(answer_types)) == 1:
|
||||
answer_text = answer_types[0]
|
||||
if is_chinese:
|
||||
full_answer_text = f',题目有多个答案,答案类型均为{answer_text}'
|
||||
else:
|
||||
full_answer_text = f'The problem has multiple answers, each of them should be {answer_text}. '
|
||||
else:
|
||||
if is_chinese:
|
||||
answer_text = '、'.join(answer_types)
|
||||
full_answer_text = f',题目有多个答案,答案类型分别为{answer_text}'
|
||||
else:
|
||||
answer_text = ', '.join(answer_types)
|
||||
full_answer_text = (
|
||||
f'The problem has multiple answers, with the answers in order being {answer_text}. '
|
||||
)
|
||||
return full_answer_text
|
||||
|
||||
|
||||
def make_input(prompt, question_content):
|
||||
# diversified based on the vllm, which is not implemented temporarily
|
||||
input = prompt + '\n' + question_content
|
||||
return input
|
||||
|
||||
|
||||
sys.set_int_max_str_digits(1000000)
|
||||
# 设置decimal的精度
|
||||
getcontext().prec = 50
|
||||
|
||||
|
||||
class MathJudger:
|
||||
def __init__(self):
|
||||
self.special_signal_map = {
|
||||
"\\left": "",
|
||||
"\\right": "",
|
||||
"∶": ":",
|
||||
",": ",",
|
||||
"$": "",
|
||||
"\\approx": "=",
|
||||
"\\simeq": "=",
|
||||
"\\sim": "=",
|
||||
"^\\prime": "'",
|
||||
"^{\\prime}": "'",
|
||||
"^\\circ": "",
|
||||
"%": "",
|
||||
}
|
||||
self.pi = parse_latex("\\pi")
|
||||
self.precision = 1e-8
|
||||
|
||||
def split_by_comma(self, expr: str):
|
||||
in_bracket_num = 0
|
||||
splitted_expr = []
|
||||
start_idx = 0
|
||||
for i, char in enumerate(expr):
|
||||
if char == "(" or char == "[":
|
||||
in_bracket_num += 1
|
||||
elif char == ")" or char == "]":
|
||||
in_bracket_num -= 1
|
||||
elif char == "," and in_bracket_num == 0:
|
||||
splitted_expr.append(expr[start_idx:i].strip())
|
||||
start_idx = i + 1
|
||||
|
||||
if start_idx < len(expr):
|
||||
splitted_expr.append(expr[start_idx:].strip())
|
||||
|
||||
return splitted_expr
|
||||
|
||||
def trans_plus_minus_sign(self, expr_list: list):
|
||||
new_expr_list = []
|
||||
for expr in expr_list:
|
||||
if "\\pm" in expr:
|
||||
new_expr_list.append(expr.replace("\\pm", "+"))
|
||||
new_expr_list.append(expr.replace("\\pm", "-"))
|
||||
else:
|
||||
new_expr_list.append(expr)
|
||||
|
||||
return new_expr_list
|
||||
|
||||
def judge(self, expression1, expression2, precision=1e-8):
|
||||
# (默认 expression1 为 Ground_Truth)
|
||||
precision = precision if isinstance(precision, list) else [precision]
|
||||
|
||||
try:
|
||||
expression1, expression2 = self.preprocess(expression1, expression2)
|
||||
except:
|
||||
return False
|
||||
if expression1 == expression2:
|
||||
# print("原生相等")
|
||||
return True
|
||||
|
||||
# 去除字符串中的中文字符,因为上面已经判断过了类似回答为"能"或"不能"的含有中文字符的回答情况
|
||||
expression1 = re.sub(r'[\u4e00-\u9fff]+', '', expression1)
|
||||
expression2 = re.sub(r'[\u4e00-\u9fff]+', '', expression2)
|
||||
|
||||
expression1 = self.split_by_comma(expression1)
|
||||
expression2 = self.split_by_comma(expression2)
|
||||
|
||||
temp_list1 = self.trans_plus_minus_sign(expression1)
|
||||
temp_list2 = self.trans_plus_minus_sign(expression2)
|
||||
|
||||
# 设计误差值列表
|
||||
if len(precision) <= 1:
|
||||
precision = precision * len(temp_list1)
|
||||
|
||||
if len(temp_list1) != len(temp_list2):
|
||||
return False
|
||||
|
||||
# 判断两个列表中的元素是否可以两两配对,并且两两相等,由此支持多个回答的比较
|
||||
idx = -1
|
||||
while len(temp_list1) != 0:
|
||||
idx = (idx + 1) % len(temp_list1)
|
||||
|
||||
item1 = temp_list1[idx]
|
||||
self.precision = precision[idx]
|
||||
# print(self.precision)
|
||||
|
||||
for item2 in temp_list2:
|
||||
if self.is_equal(item1, item2):
|
||||
temp_list1.remove(item1)
|
||||
temp_list2.remove(item2)
|
||||
precision.remove(self.precision)
|
||||
break
|
||||
else:
|
||||
# If we didn't break from the inner loop, it means no match was found
|
||||
return False
|
||||
|
||||
# If all elements are matched and removed, the lists can be paired
|
||||
return True
|
||||
|
||||
def is_interval(self, epr):
|
||||
return epr.startswith(("(", "[")) and epr.endswith((")", "]"))
|
||||
|
||||
# 在进行数值计算前,需要将sympy中的pi符号替换为pi的近似数值
|
||||
# def sympy_sub_pi(self, expression_sympy):
|
||||
# return expression_sympy.subs(self.pi, math.pi)
|
||||
|
||||
# 默认第一个表达式是 ground_truth
|
||||
def is_equal(self, expression1, expression2):
|
||||
if expression1 == expression2 and expression1 != "" and expression2 != "":
|
||||
# print("原生等价")
|
||||
return True
|
||||
|
||||
# 先判断是否是两个区间,是的话进行判断相等,不相等则返回 False
|
||||
if self.is_interval(expression1) and self.is_interval(expression2):
|
||||
try:
|
||||
if self.interval_equal(expression1, expression2):
|
||||
# print("区间等价")
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
# 再判断是否在数值上相等
|
||||
try:
|
||||
if self.numerical_equal(expression1, expression2):
|
||||
# print("数值等价")
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
# 再判断是否是表达式相等
|
||||
try:
|
||||
if self.expression_equal(expression1, expression2) and not ("=" in expression1 and "=" in expression2):
|
||||
# print("表达式等价")
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
# 再判断是否是等式相等
|
||||
try:
|
||||
if self.equation_equal(expression1, expression2):
|
||||
# print("等式等价")
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
# 判断两个数值在误差允许范围内是否相等
|
||||
def numerical_equal(self, expression1: str, expression2: str, include_percentage: bool = True):
|
||||
"""
|
||||
(默认 expression1 为 Ground_Truth)
|
||||
函数: 判读两个数值是否在误差允许范围内相等
|
||||
步骤1: 将可能出现的百分号的情况包含进来
|
||||
步骤2: 使用 math.isclose 函数判断是否相等
|
||||
"""
|
||||
reference = float(expression1)
|
||||
prediction = float(expression2)
|
||||
|
||||
if include_percentage:
|
||||
gt_result = [reference / 100, reference, reference * 100]
|
||||
else:
|
||||
gt_result = [reference]
|
||||
|
||||
for item in gt_result:
|
||||
# if isclose(item, prediction, abs_tol=self.precision, rel_tol=0):
|
||||
if abs(item - prediction) <= self.precision * 1.01:
|
||||
return True
|
||||
return False
|
||||
|
||||
def expression_equal(self, exp1, exp2):
|
||||
"""
|
||||
(默认 expression1 为 Ground_Truth)
|
||||
函数: 判断两个表达式是否在数学意义上等价
|
||||
步骤1: 提取表达式, 防止有的模型会给出"x=1"而不是"1"
|
||||
步骤2: 使用 sympy 库进行等价判断
|
||||
"""
|
||||
|
||||
# 只提取等号右边的表达式,一般左边是所求的量
|
||||
def extract_expression(expression):
|
||||
if "=" in expression:
|
||||
expression = expression.split("=")[1]
|
||||
return expression.strip()
|
||||
|
||||
exp1 = extract_expression(exp1)
|
||||
exp2 = extract_expression(exp2)
|
||||
|
||||
exp_too_long = len(exp1) > 300 or len(exp2) > 300
|
||||
|
||||
# 将表达式转换为 sympy 中能够进行处理的格式
|
||||
expr1_sym = sympify(parse_latex(exp1))
|
||||
expr2_sym = sympify(parse_latex(exp2))
|
||||
|
||||
if expr1_sym == expr2_sym:
|
||||
return True
|
||||
else:
|
||||
expr1_sym = self.sympy_sub_pi(expr1_sym)
|
||||
expr2_sym = self.sympy_sub_pi(expr2_sym)
|
||||
# 如果输入的表达式可以计算出具体数值的话,则将其进行数值计算的比较
|
||||
|
||||
if (expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol)) or (
|
||||
not expr1_sym.has(sp.Symbol) and expr2_sym.has(sp.Symbol)):
|
||||
return False
|
||||
elif not expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol):
|
||||
try:
|
||||
if not (self.can_compute_power(expr1_sym) and self.can_compute_power(expr2_sym)):
|
||||
print(
|
||||
"These two number can not be calculated by current computer for: "
|
||||
f"\"{str(expr1_sym)}\" and \"{str(expr2_sym)}\""
|
||||
)
|
||||
return False
|
||||
if exp_too_long:
|
||||
print(f'Expression {exp1} or {exp2} is too long to compute. ')
|
||||
return False
|
||||
|
||||
if abs(expr1_sym.evalf() - expr2_sym.evalf()) <= self.precision * 1.01:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except:
|
||||
return False
|
||||
elif exp_too_long:
|
||||
print(f'Expression {exp1} or {exp2} is too long to compute. ')
|
||||
return False
|
||||
else:
|
||||
try:
|
||||
simplified_expr = simplify(expr1_sym - expr2_sym)
|
||||
|
||||
num_value = simplified_expr.evalf()
|
||||
|
||||
return abs(num_value) < 1e-3
|
||||
except:
|
||||
return False
|
||||
|
||||
def equation_equal(self, expression1, expression2):
|
||||
"""
|
||||
(默认 expression1 为 Ground_Truth)
|
||||
函数: 判断两个方程是否在数学意义上等价
|
||||
步骤1: 将一个方程/等式化简为标准方程, 即等式的右边严格等于0, 接下来只需要判断两个等式的左边是否"等价"
|
||||
步骤2: 使用 sympy 库计算两个等式左边的商, 如果这个商或者这个商的倒数为整数, 那么数学意义上我们可以推导出这两个方程等价👌
|
||||
"""
|
||||
|
||||
# 将等式的右边都移到左边,并返回一个 sympy 格式的表达式
|
||||
def simplify_equation(latex_eq):
|
||||
# 分割等式的左边和右边
|
||||
lhs, rhs = latex_eq.split('=')
|
||||
|
||||
# 使用 parse_latex 解析 LaTeX 表达式
|
||||
lhs_expr = parse_latex(lhs)
|
||||
rhs_expr = parse_latex(rhs)
|
||||
|
||||
# 创建等式对象
|
||||
equation = Eq(lhs_expr, rhs_expr)
|
||||
|
||||
# 化简等式:将等式右边移到左边
|
||||
simplified_eq = simplify(equation.lhs - equation.rhs)
|
||||
|
||||
return simplified_eq
|
||||
|
||||
expr1_sym = simplify_equation(expression1)
|
||||
expr2_sym = simplify_equation(expression2)
|
||||
|
||||
division_result_1 = simplify(expr1_sym / expr2_sym)
|
||||
division_result_2 = simplify(expr2_sym / expr1_sym)
|
||||
|
||||
# 如果两个方程转换后的式子相除为整数 且非零,则根据推导可知这两个方程等价
|
||||
if (division_result_1.is_Integer and division_result_1 != 0) or (
|
||||
division_result_2.is_Integer and division_result_2 != 0):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def interval_equal(self, expression1, expression2):
|
||||
# 函数: 判断两个区间是否在数学意义上等价
|
||||
# 步骤1: 简化区间的表达式, 去除无关的符号比如"\left", "\right", 同时将可能出现的"x \in"删去
|
||||
# 步骤2: 对比两个区间的左右符号、中间出现的数学表达式等是否一致
|
||||
|
||||
def compare_two_interval(inter1, inter2):
|
||||
|
||||
# 首先比较两边的括号是否一致,一致的话再进行下一步比较
|
||||
if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]:
|
||||
return False
|
||||
|
||||
inter1 = inter1.strip('[]()')
|
||||
inter2 = inter2.strip('[]()')
|
||||
|
||||
# 分割区间的左右部分
|
||||
items_1 = inter1.split(',')
|
||||
items_2 = inter2.split(',')
|
||||
|
||||
for item_1, item_2 in zip(items_1, items_2):
|
||||
if not self.expression_equal(item_1, item_2):
|
||||
return False
|
||||
return True
|
||||
|
||||
interval1 = expression1
|
||||
interval2 = expression2
|
||||
|
||||
if interval1 == interval2:
|
||||
return True
|
||||
else:
|
||||
inter_list1 = interval1.split("\\cup")
|
||||
inter_list2 = interval2.split("\\cup")
|
||||
|
||||
if len(inter_list1) != len(inter_list2):
|
||||
return False
|
||||
else:
|
||||
for inter1, inter2 in zip(inter_list1, inter_list2):
|
||||
if not compare_two_interval(inter1, inter2):
|
||||
return False
|
||||
return True
|
||||
|
||||
def preprocess(self, expression1, expression2):
|
||||
|
||||
# 尝试捕获box中的内容,如果有多个则以逗号相连返回,如果一个都没有,则报错
|
||||
def extract_boxed_content(latex_str):
|
||||
# 查找所有的 \boxed{...} 结构
|
||||
boxed_matches = re.finditer(r'\\boxed{', latex_str)
|
||||
results = ""
|
||||
|
||||
for match in boxed_matches:
|
||||
start_index = match.end()
|
||||
end_index = start_index
|
||||
stack = 1
|
||||
|
||||
# 从 \boxed{ 之后开始搜索,直到找到对应的闭合括号
|
||||
while stack > 0 and end_index < len(latex_str):
|
||||
if latex_str[end_index] == '{':
|
||||
stack += 1
|
||||
elif latex_str[end_index] == '}':
|
||||
stack -= 1
|
||||
end_index += 1
|
||||
|
||||
if stack == 0:
|
||||
# 提取 \boxed{} 内部的内容
|
||||
content = latex_str[start_index:end_index - 1]
|
||||
results += content + ","
|
||||
else:
|
||||
# 如果括号没有正确闭合,则返回错误信息
|
||||
raise ValueError("Mismatched braces in LaTeX string.")
|
||||
|
||||
# 如果没有匹配到'\boxed{}'字符,则默认提取有内容的文字最后一行中的所有公式部分
|
||||
if results == "":
|
||||
last_line_ans = latex_str.strip().split("\n")[-1]
|
||||
dollar_pattern = r"\$(.*?)\$"
|
||||
answers = re.findall(dollar_pattern, last_line_ans)
|
||||
|
||||
if answers:
|
||||
for ans in answers:
|
||||
results += ans + ","
|
||||
else:
|
||||
results = latex_str
|
||||
|
||||
return results
|
||||
|
||||
def sepcial_symbol_replace(expression):
|
||||
if "\\in " in expression:
|
||||
expression = expression.split("\\in ")[1]
|
||||
|
||||
# 进行特殊字符的替换,这些字符都不影响latex的解析,属于美观/修饰性字符
|
||||
for signal in self.special_signal_map:
|
||||
expression = expression.replace(signal, self.special_signal_map[signal])
|
||||
|
||||
expression = expression.strip("\n$,.:;^_=+`!@#$%^&*~,。")
|
||||
|
||||
pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}'
|
||||
expression = re.sub(pattern, r'\1', expression)
|
||||
|
||||
return expression
|
||||
|
||||
exp1, exp2 = extract_boxed_content(expression1), extract_boxed_content(expression2)
|
||||
exp1, exp2 = sepcial_symbol_replace(exp1), sepcial_symbol_replace(exp2)
|
||||
|
||||
return exp1, exp2
|
||||
|
||||
def can_compute_power(self, expr):
|
||||
"""
|
||||
Check if the power expression can be computed.
|
||||
|
||||
Parameters:
|
||||
expr (sympy expression): The expression to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the expression can be computed, False otherwise.
|
||||
"""
|
||||
# Check if the expression is a power expression
|
||||
if isinstance(expr, Pow):
|
||||
# Extract the base and the exponent
|
||||
base, exp = expr.as_base_exp()
|
||||
|
||||
# Check if the base and the exponent are numbers
|
||||
if base.is_number and exp.is_number:
|
||||
# Set a threshold for the maximum size of the exponent
|
||||
MAX_EXP = 1000 # This threshold can be adjusted based on the computing environment
|
||||
|
||||
# Check if the exponent is greater than the threshold
|
||||
if abs(exp.evalf()) > MAX_EXP:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
else:
|
||||
# If the base or the exponent is not a number, we cannot compute the power
|
||||
return False
|
||||
else:
|
||||
# If the expression is not a power expression, return True as it is not the case we are checking for
|
||||
return True
|
||||
|
||||
|
||||
def extract_answer(is_chinese, model_output, is_deepseek=False):
|
||||
# deepseekmath has special answering format
|
||||
if str(model_output) == 'nan':
|
||||
model_output = 'nan'
|
||||
|
||||
if is_deepseek:
|
||||
if is_chinese:
|
||||
matches = re.findall('## 解题答案(.*)', model_output)
|
||||
else:
|
||||
matches = re.findall('The answer is: (.*)', model_output)
|
||||
|
||||
# 检测是否至少找到一个匹配,如果没有就直接整个送进去找\boxed{}
|
||||
if matches:
|
||||
# 如果找到多个匹配,取最后一个
|
||||
model_answer = matches[-1].strip()
|
||||
return model_answer
|
||||
else:
|
||||
return model_output
|
||||
|
||||
if is_chinese:
|
||||
matches = re.findall('所以最终答案是(.*)', model_output)
|
||||
else:
|
||||
matches = re.findall('So the final answer is (.*)', model_output)
|
||||
|
||||
# 检测是否至少找到一个匹配,如果没有就直接整个送进去找\boxed{}
|
||||
if matches:
|
||||
# 如果找到多个匹配,取最后一个
|
||||
model_answer = matches[-1].strip()
|
||||
return model_answer
|
||||
else:
|
||||
return model_output
|
||||
|
||||
|
||||
def calculate_merged_accuracy(reference_dir, text_only):
|
||||
pass
|
||||
Reference in New Issue
Block a user