Files
MiniCPM-o/eval_mm/vlmevalkit/vlmeval/dataset/utils/olympiadbench.py
2025-01-21 15:34:54 +08:00

533 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
import json
from math import isclose
import sympy as sp
from sympy import simplify, Eq, sympify, evalf, Pow
from sympy.parsing.latex import parse_latex
import antlr4
from decimal import Decimal, getcontext
from fractions import Fraction
import sys
import math
# Display text for each supported answer type, used when the prompt is written in Chinese.
chinese_answer_type_dict = dict(
    Numerical='数值',
    Expression='表达式',
    Equation='方程',
    Interval='区间',
)
# Display text (including the article) for each supported answer type in English prompts.
english_answer_type_dict = dict(
    Numerical='a numerical value',
    Expression='an expression',
    Equation='an equation',
    Interval='an interval',
)
def get_single_answer_type_text(answer_type, is_chinese):
    """Return the display text for a single answer type.

    Anything from the first '-' onwards in ``answer_type`` is ignored. The
    remaining text is searched for the known type names in the fixed order
    Numerical, Expression, Equation, Interval, and the first name found is
    looked up in the Chinese or English table depending on ``is_chinese``.

    If no known type name is found the process is terminated through
    ``exit`` with an error message.
    """
    # Drop the '-suffix' part, if any (kept for compatibility; "No need now").
    base_type = answer_type.partition('-')[0]
    for candidate in ('Numerical', 'Expression', 'Equation', 'Interval'):
        if candidate not in base_type:
            continue
        table = chinese_answer_type_dict if is_chinese else english_answer_type_dict
        return table[candidate]
    exit(f'Error parsing answer type {base_type}!')
def get_answer_type_text(answer_type, is_chinese, multiple_answer):
    """Build the prompt fragment that tells the model which answer type is expected.

    Args:
        answer_type: Answer type label; several types may be joined with ','.
        is_chinese: Produce the Chinese fragment when true, the English one otherwise.
        multiple_answer: Whether the problem has more than one answer.

    Returns:
        The prompt fragment, or '' when the type contains 'Need_human_evaluate'
        or 'Tuple'.
    """
    # 'Tuple' has various meanings in different context, such as position or values of a series of variable,
    # so it may lead to confusion to directly use 'tuple' in the prompt.
    if ('Need_human_evaluate' in answer_type) or ('Tuple' in answer_type):
        return ''

    if not multiple_answer:
        answer_text = get_single_answer_type_text(answer_type, is_chinese)
        if is_chinese:
            return f',答案类型为{answer_text}'
        return f"The answer of The problem should be {answer_text}. "

    # Splitting a comma-free label yields a single element, so the "same type for
    # all answers" case and the explicit per-answer list share one code path.
    answer_types = [get_single_answer_type_text(t, is_chinese) for t in answer_type.split(',')]
    if len(set(answer_types)) == 1:
        answer_text = answer_types[0]
        if is_chinese:
            return f',题目有多个答案,答案类型均为{answer_text}'
        return f'The problem has multiple answers, each of them should be {answer_text}. '

    if is_chinese:
        # Separate the type names with an enumeration comma (U+3001); joining them
        # with an empty string runs the names together (e.g. "数值表达式").
        answer_text = '\u3001'.join(answer_types)
        return f',题目有多个答案,答案类型分别为{answer_text}'
    answer_text = ', '.join(answer_types)
    return f'The problem has multiple answers, with the answers in order being {answer_text}. '
def make_input(prompt, question_content):
    """Return the prompt followed by a newline and the question content."""
    # diversified based on the vllm, which is not implemented temporarily
    return prompt + '\n' + question_content
# Raise the interpreter's limit on int <-> str conversion length (presumably so that
# very large integers met while judging answers can still be converted).
# The setter does not exist on older interpreters, so guard the call instead of
# failing with AttributeError when this module is imported.
if hasattr(sys, 'set_int_max_str_digits'):
    sys.set_int_max_str_digits(1000000)
# Set the precision of the decimal module's current context.
getcontext().prec = 50
class MathJudger:
    """Decide whether a model answer matches a reference answer.

    Answers are LaTeX-flavoured strings. ``judge`` is the entry point: it
    normalises both strings, splits them into individual answers and tries to
    pair every reference answer with a prediction. Pairs are compared, in
    order, as raw strings, intervals, plain numbers, expressions and equations.
    """

    # Decorative or look-alike tokens that are normalised before comparison.
    # The replacements are applied in this order, so "\\simeq" has to come
    # before "\\sim" (otherwise "\\simeq" would be turned into "=eq").
    # The non-ASCII keys are spelled as escapes so they cannot be lost or
    # confused with their ASCII look-alikes; an empty-string key must never
    # appear here because str.replace("", x) inserts x between every character.
    _SPECIAL_SIGNAL_MAP = {
        "\\left": "",
        "\\right": "",
        "\u2236": ":",  # RATIO
        "\uff1a": ":",  # FULLWIDTH COLON
        "\uff0c": ",",  # FULLWIDTH COMMA
        "$": "",
        "\\approx": "=",
        "\\simeq": "=",
        "\\sim": "=",
        "^\\prime": "'",
        "^{\\prime}": "'",
        "^\\circ": "",
        "%": "",
    }

    def __init__(self):
        # Per-instance copy so that callers may adjust the map without
        # affecting other judges.
        self.special_signal_map = dict(self._SPECIAL_SIGNAL_MAP)
        self.pi = parse_latex("\\pi")
        # Absolute tolerance used by the numeric comparisons; ``judge``
        # overwrites it for every answer it compares.
        self.precision = 1e-8

    def split_by_comma(self, expr: str):
        """Split ``expr`` on commas that are not inside ``()`` or ``[]``.

        Each piece is stripped of surrounding whitespace. An empty trailing
        piece (nothing after the last top-level comma) is not emitted.
        """
        in_bracket_num = 0
        splitted_expr = []
        start_idx = 0
        for i, char in enumerate(expr):
            if char in "([":
                in_bracket_num += 1
            elif char in ")]":
                in_bracket_num -= 1
            elif char == "," and in_bracket_num == 0:
                splitted_expr.append(expr[start_idx:i].strip())
                start_idx = i + 1
        if start_idx < len(expr):
            splitted_expr.append(expr[start_idx:].strip())
        return splitted_expr

    def trans_plus_minus_sign(self, expr_list: list):
        """Expand every expression containing ``\\pm`` into a '+' and a '-' variant."""
        new_expr_list = []
        for expr in expr_list:
            if "\\pm" in expr:
                new_expr_list.append(expr.replace("\\pm", "+"))
                new_expr_list.append(expr.replace("\\pm", "-"))
            else:
                new_expr_list.append(expr)
        return new_expr_list

    def judge(self, expression1, expression2, precision=1e-8):
        """Return True when ``expression2`` matches ``expression1``.

        ``expression1`` is treated as the ground truth.

        Args:
            expression1: Reference answer text.
            expression2: Predicted answer text.
            precision: Absolute tolerance, either a single number applied to
                every answer or a list with one entry per reference answer
                (after ``\\pm`` expansion). The caller's list is not modified.

        Returns:
            True if every reference answer can be paired with an equal
            prediction, False otherwise (including when preprocessing fails).

        Raises:
            IndexError: if a precision list is shorter than the number of
                reference answers being compared.
        """
        # Work on a copy: entries are deleted from this list as answers are matched.
        precision = list(precision) if isinstance(precision, list) else [precision]

        try:
            expression1, expression2 = self.preprocess(expression1, expression2)
        except Exception:
            return False
        if expression1 == expression2:
            return True

        # Remove Chinese characters; answers made of Chinese text (such as
        # "能" / "不能") were already handled by the exact comparison above.
        expression1 = re.sub(r'[\u4e00-\u9fff]+', '', expression1)
        expression2 = re.sub(r'[\u4e00-\u9fff]+', '', expression2)

        references = self.trans_plus_minus_sign(self.split_by_comma(expression1))
        predictions = self.trans_plus_minus_sign(self.split_by_comma(expression2))

        # A single tolerance applies to every reference answer.
        if len(precision) <= 1:
            precision = precision * len(references)

        if len(references) != len(predictions):
            return False

        # Check that the two lists can be paired off element by element, which
        # supports answers given in a different order.
        idx = -1
        while references:
            idx = (idx + 1) % len(references)
            item1 = references[idx]
            self.precision = precision[idx]
            for pos, item2 in enumerate(predictions):
                if self.is_equal(item1, item2):
                    # Delete by position so that the tolerance list stays
                    # aligned with the remaining reference answers.
                    del references[idx]
                    del precision[idx]
                    del predictions[pos]
                    break
            else:
                # No prediction matched this reference answer.
                return False

        # Every reference answer was matched and removed.
        return True

    def is_interval(self, epr):
        """Return True if ``epr`` starts with '(' or '[' and ends with ')' or ']'."""
        return epr.startswith(("(", "[")) and epr.endswith((")", "]"))

    def sympy_sub_pi(self, expression_sympy):
        """Replace the parsed ``\\pi`` symbol with the float ``math.pi``.

        This has to happen before numeric evaluation so that expressions that
        only contain pi are treated as numbers rather than symbolic ones.
        """
        return expression_sympy.subs(self.pi, math.pi)

    def is_equal(self, expression1, expression2):
        """Compare a single reference answer with a single prediction.

        ``expression1`` is treated as the ground truth. The checks run in this
        order: identical non-empty strings, intervals, plain numbers,
        expressions, equations. A failing check falls through to the next one.
        """
        if expression1 == expression2 and expression1 != "" and expression2 != "":
            return True

        # Intervals first: an error while comparing two intervals is final.
        if self.is_interval(expression1) and self.is_interval(expression2):
            try:
                if self.interval_equal(expression1, expression2):
                    return True
            except Exception:
                return False

        # Then plain numbers.
        try:
            if self.numerical_equal(expression1, expression2):
                return True
        except Exception:
            pass

        # Then expressions; skipped as a verdict when both sides are equations.
        try:
            if self.expression_equal(expression1, expression2) and not (
                    "=" in expression1 and "=" in expression2):
                return True
        except Exception:
            pass

        # Finally equations.
        try:
            if self.equation_equal(expression1, expression2):
                return True
        except Exception:
            pass

        return False

    def numerical_equal(self, expression1: str, expression2: str, include_percentage: bool = True):
        """Return True if two numeric strings agree within ``self.precision``.

        ``expression1`` is treated as the ground truth. When
        ``include_percentage`` is true the reference is also accepted divided
        by 100 and multiplied by 100, to cover percentage notation.

        Raises:
            ValueError: if either string cannot be converted with ``float``.
        """
        reference = float(expression1)
        prediction = float(expression2)

        if include_percentage:
            gt_result = [reference / 100, reference, reference * 100]
        else:
            gt_result = [reference]

        # The tolerance is widened by 1%.
        tolerance = self.precision * 1.01
        return any(abs(item - prediction) <= tolerance for item in gt_result)

    def expression_equal(self, exp1, exp2):
        """Return True if two LaTeX expressions are mathematically equivalent.

        ``exp1`` is treated as the ground truth. If an input contains '=',
        only the text between the first and second '=' is used, because models
        sometimes answer "x=1" instead of "1". Purely numeric expressions are
        compared within ``self.precision``; symbolic ones are compared by
        simplifying their difference.
        """

        def extract_expression(expression):
            # Keep only the part after the first '='; the left side is
            # normally the quantity being asked for.
            if "=" in expression:
                expression = expression.split("=")[1]
            return expression.strip()

        exp1 = extract_expression(exp1)
        exp2 = extract_expression(exp2)

        exp_too_long = len(exp1) > 300 or len(exp2) > 300

        # Convert both expressions into sympy objects.
        expr1_sym = sympify(parse_latex(exp1))
        expr2_sym = sympify(parse_latex(exp2))

        if expr1_sym == expr2_sym:
            return True

        expr1_sym = self.sympy_sub_pi(expr1_sym)
        expr2_sym = self.sympy_sub_pi(expr2_sym)

        has_symbol_1 = expr1_sym.has(sp.Symbol)
        has_symbol_2 = expr2_sym.has(sp.Symbol)

        # One side is symbolic and the other is a number: not equal.
        if has_symbol_1 != has_symbol_2:
            return False

        if not has_symbol_1:
            # Both sides evaluate to numbers, so compare their values.
            try:
                if not (self.can_compute_power(expr1_sym) and self.can_compute_power(expr2_sym)):
                    print(
                        "These two number can not be calculated by current computer for: "
                        f"\"{str(expr1_sym)}\" and \"{str(expr2_sym)}\""
                    )
                    return False
                if exp_too_long:
                    print(f'Expression {exp1} or {exp2} is too long to compute. ')
                    return False
                return bool(abs(expr1_sym.evalf() - expr2_sym.evalf()) <= self.precision * 1.01)
            except Exception:
                return False

        if exp_too_long:
            print(f'Expression {exp1} or {exp2} is too long to compute. ')
            return False

        try:
            simplified_expr = simplify(expr1_sym - expr2_sym)
            num_value = simplified_expr.evalf()
            # bool() raises when the comparison cannot be decided (symbols
            # remain after simplification), which counts as "not equal".
            return bool(abs(num_value) < 1e-3)
        except Exception:
            return False

    def equation_equal(self, expression1, expression2):
        """Return True if two LaTeX equations are equivalent.

        ``expression1`` is treated as the ground truth. Each equation is
        rewritten as ``lhs - rhs``; the equations are considered equivalent
        when the quotient of the two results, in either direction, simplifies
        to a non-zero integer.

        Raises:
            ValueError: if an input does not contain exactly one '='.
        """

        def simplify_equation(latex_eq):
            # Split into the two sides, parse them and move the right side
            # over to the left.
            lhs, rhs = latex_eq.split('=')
            lhs_expr = parse_latex(lhs)
            rhs_expr = parse_latex(rhs)
            equation = Eq(lhs_expr, rhs_expr)
            return simplify(equation.lhs - equation.rhs)

        expr1_sym = simplify_equation(expression1)
        expr2_sym = simplify_equation(expression2)

        division_result_1 = simplify(expr1_sym / expr2_sym)
        division_result_2 = simplify(expr2_sym / expr1_sym)

        return bool(
            (division_result_1.is_Integer and division_result_1 != 0)
            or (division_result_2.is_Integer and division_result_2 != 0)
        )

    def interval_equal(self, expression1, expression2):
        """Return True if two intervals (or unions joined by ``\\cup``) are equal.

        The unions must have the same number of parts, and parts are compared
        in order: the bracket types must agree and the bounds must be equal as
        expressions.
        """

        def compare_two_interval(inter1, inter2):
            # Whitespace around a part (for example next to "\\cup") must not
            # take part in the bracket comparison.
            inter1 = inter1.strip()
            inter2 = inter2.strip()
            if not inter1 or not inter2:
                return False

            # The brackets on both ends have to agree before the bounds are compared.
            if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]:
                return False

            items_1 = inter1.strip('[]()').split(',')
            items_2 = inter2.strip('[]()').split(',')

            # zip() would silently ignore extra bounds on the longer side.
            if len(items_1) != len(items_2):
                return False

            for item_1, item_2 in zip(items_1, items_2):
                if not self.expression_equal(item_1, item_2):
                    return False
            return True

        if expression1 == expression2:
            return True

        inter_list1 = expression1.split("\\cup")
        inter_list2 = expression2.split("\\cup")
        if len(inter_list1) != len(inter_list2):
            return False

        for inter1, inter2 in zip(inter_list1, inter_list2):
            if not compare_two_interval(inter1, inter2):
                return False
        return True

    def preprocess(self, expression1, expression2):
        """Extract and normalise the answer text of both inputs.

        Returns:
            A tuple ``(exp1, exp2)`` of cleaned answer strings.

        Raises:
            ValueError: if a ``\\boxed{`` has no matching closing brace.
        """

        def extract_boxed_content(latex_str):
            # Collect the content of every \boxed{...}, each followed by a comma.
            contents = []
            for match in re.finditer(r'\\boxed{', latex_str):
                start_index = match.end()
                end_index = start_index
                depth = 1
                # Scan forward from just after "\boxed{" to the matching brace.
                while depth > 0 and end_index < len(latex_str):
                    char = latex_str[end_index]
                    if char == '{':
                        depth += 1
                    elif char == '}':
                        depth -= 1
                    end_index += 1
                if depth != 0:
                    raise ValueError("Mismatched braces in LaTeX string.")
                contents.append(latex_str[start_index:end_index - 1])
            if contents:
                return "".join(content + "," for content in contents)

            # Without any \boxed{}, fall back to the $...$ formulas on the last
            # line, or to the whole text if there are none.
            last_line_ans = latex_str.strip().split("\n")[-1]
            answers = re.findall(r"\$(.*?)\$", last_line_ans)
            if answers:
                return "".join(ans + "," for ans in answers)
            return latex_str

        def special_symbol_replace(expression):
            # Keep only what follows "\in " (drops a leading "x \in").
            if "\\in " in expression:
                expression = expression.split("\\in ")[1]

            # Replace decorative tokens that do not affect LaTeX parsing.
            for signal, replacement in self.special_signal_map.items():
                expression = expression.replace(signal, replacement)

            expression = expression.strip("\n$,.:;^_=+`!@#$%^&*~,。")

            # Unwrap \mathrm{...} and \mathbf{...}, dropping a leading '~'.
            pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}'
            return re.sub(pattern, r'\1', expression)

        exp1 = special_symbol_replace(extract_boxed_content(expression1))
        exp2 = special_symbol_replace(extract_boxed_content(expression2))
        return exp1, exp2

    def can_compute_power(self, expr):
        """
        Check if the power expression can be computed.

        Parameters:
        expr (sympy expression): The expression to check.

        Returns:
        bool: True if the expression can be computed, False otherwise.
        """
        # Anything that is not a power at the top level is not the case being
        # checked for.
        if not isinstance(expr, Pow):
            return True

        base, exp = expr.as_base_exp()
        # A non-numeric base or exponent cannot be evaluated.
        if not (base.is_number and exp.is_number):
            return False

        # Threshold for the size of the exponent; it can be adjusted based on
        # the computing environment.
        MAX_EXP = 1000
        return not abs(exp.evalf()) > MAX_EXP
def extract_answer(is_chinese, model_output, is_deepseek=False):
    """Pull the final answer out of a model response.

    The response is searched for a language-specific marker phrase
    (deepseekmath uses its own answering format, selected by ``is_deepseek``).
    The text following the last occurrence of the marker, up to the end of
    that line, is returned stripped. If the marker does not occur, the whole
    response is returned so that the caller can look for ``\\boxed{}`` in it.
    """
    # A value whose string form is 'nan' is replaced by the string 'nan'.
    if str(model_output) == 'nan':
        model_output = 'nan'

    if is_deepseek:
        marker = '## 解题答案(.*)' if is_chinese else 'The answer is: (.*)'
    else:
        marker = '所以最终答案是(.*)' if is_chinese else 'So the final answer is (.*)'

    found = re.findall(marker, model_output)
    # With several matches, the last one wins.
    return found[-1].strip() if found else model_output
def calculate_merged_accuracy(reference_dir, text_only):
    """Placeholder: accepts both arguments, does no work and returns None."""
    return None