# Mirror of https://github.com/OpenBMB/MiniCPM-V.git
# Synced 2026-02-04 09:49:20 +08:00 (172 lines, 5.1 KiB, Python).
from ...smp import *
|
|
from ...utils import can_infer
|
|
try:
|
|
from latex2sympy2 import latex2sympy
|
|
except Exception as e:
|
|
logging.critical(f'{type(e)}: {e}')
|
|
logging.critical('Please install latex2sympy2 by running "pip install latex2sympy2"')
|
|
|
|
# Substring that marks a failed judge call: MATH_V_auxeval checks each
# model.generate() output for it and moves on to the next attempt when present.
FAIL_MSG = 'Failed to obtain answer via API.'
|
|
|
|
|
|
def is_equal(asw: str, gt_asw: str) -> bool:
    """Return True if a predicted answer matches the ground-truth answer.

    Both inputs are coerced with ``str()``, lower-cased and stripped, then
    compared in three stages, stopping at the first match:

    1. exact string equality;
    2. numeric equality (within 1e-6) after Python ``eval`` of both strings;
    3. numeric equality (within 1e-6) after parsing both as LaTeX with
       ``latex2sympy``.

    Args:
        asw: The predicted answer. Non-string values are accepted but cause
            a warning to be printed.
        gt_asw: The ground-truth answer, coerced the same way.

    Returns:
        True on the first successful comparison, otherwise False. Errors
        raised inside stages 2 and 3 (unparsable text, non-numeric values,
        or ``latex2sympy`` being undefined because its import failed at
        module load) are treated as "not equal".
    """
    # Both operands must be checked independently; the previous form
    # `not isinstance(asw, str) != str` was always False.
    if not isinstance(asw, str) or not isinstance(gt_asw, str):
        print('Warning: input is not string')
        print(asw, gt_asw)
    asw = str(asw).lower().strip()
    gt_asw = str(gt_asw).lower().strip()
    if gt_asw == asw:
        return True

    # NOTE(security): `eval` executes arbitrary Python taken from model output
    # and dataset text; only acceptable for trusted, offline evaluation.
    # SystemExit is caught too so text such as "exit()" cannot abort the run,
    # while KeyboardInterrupt is no longer swallowed.
    try:
        a = eval(gt_asw)
        b = eval(asw)
        if abs(a - b) < 1e-6:
            return True
    except (Exception, SystemExit):
        pass

    try:
        a = latex2sympy(gt_asw)
        b = latex2sympy(asw)
        if abs(eval(str(a)) - eval(str(b))) < 1e-6:
            return True
        if abs(a - b) < 1e-6:
            return True
    except (Exception, SystemExit):
        pass
    return False
|
|
|
|
|
|
def get_gpt4_ICE():
    """Return the in-context examples shown to the judge model.

    Every example uses the same layout: a leading newline, then each part
    (hint, question, optional choices, model response) followed by a blank
    line, and finally the ``Extracted answer:`` line. The first four examples
    have free-form answers (an integer, a decimal, an amount of money and a
    pair of years); the fifth is multiple choice and expects the option letter.

    Returns:
        A new list of five example strings, in prompt order.
    """
    free_form_hint = 'Hint: Please answer the question and provide the final answer at the end.'
    choice_hint = (
        'Hint: Please answer the question and provide the correct option letter, '
        'e.g., A, B, C, D, at the end.'
    )
    # Each entry: (parts printed before the answer, expected extracted answer).
    specs = [
        (
            [
                free_form_hint,
                'Question: Which number is missing?',
                'Model response: The number missing in the sequence is 14.',
            ],
            '14',
        ),
        (
            [
                free_form_hint,
                'Question: What is the fraction of females facing the camera?',
                # This response spans two lines.
                'Model response: The fraction of females facing the camera is 0.6,\n'
                'which means that six out of ten females in the group are facing the camera.',
            ],
            '0.6',
        ),
        (
            [
                free_form_hint,
                'Question: How much money does Luca need to buy a sour apple candy '
                'and a butter-scotch candy? (Unit: $)',
                'Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.',
            ],
            '1.45',
        ),
        (
            [
                free_form_hint,
                'Question: Between which two years does the line graph saw its maximum peak?',
                'Model response: The line graph saw its maximum peak between 2007 and 2008.',
            ],
            '[2007, 2008]',
        ),
        (
            [
                choice_hint,
                'Question: What fraction of the shape is blue?',
                'Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5',
                'Model response: The correct answer is (B) 8/11.',
            ],
            'B',
        ),
    ]
    return [
        '\n' + ''.join(part + '\n\n' for part in parts) + f'Extracted answer: {answer}\n'
        for parts, answer in specs
    ]
|
|
|
|
|
|
def build_mathv_gpt4_prompt(line):
    """Build the answer-extraction prompt sent to the judge model.

    The prompt consists of a short task description, the in-context examples
    from ``get_gpt4_ICE()``, and then the current item's question and model
    prediction, ending with ``'Extracted answer:'`` for the judge to complete.

    Args:
        line: Record supporting ``line['question']`` (a string) and
            ``line['prediction']`` (coerced with ``str()``).

    Returns:
        The full prompt string.
    """
    task_description = """
Please read the following example.
Then extract the answer from the model response and type it at the end of the prompt.\n
"""
    question = line['question']
    prediction = str(line['prediction'])
    prompt = task_description
    for example in get_gpt4_ICE():
        prompt += example + '\n'
    prompt += question + '\n'
    # Mirror the "Model response: ...\n" / "Extracted answer:" layout of the
    # in-context examples; previously the label was misspelled and the
    # prediction ran straight into "Extracted answer:".
    prompt += 'Model response: ' + prediction + '\n'
    prompt += 'Extracted answer:'
    return prompt
|
|
|
|
|
|
def list_to_dict(lst):
    """Key the items of ``lst`` by consecutive capital letters.

    The first item is keyed ``'A'``, the second ``'B'``, and so on; keys keep
    counting up through the character table past ``'Z'``.
    """
    return {chr(ord('A') + index): choice for index, choice in enumerate(lst)}
|
|
|
|
|
|
def post_check(line, prefetch=False):
    """Compare a record's response with its ground-truth answer.

    Args:
        line: Record supporting ``line['answer']``, ``line['choices']`` and
            either ``line['prediction']`` (when ``prefetch``) or
            ``line['res']``. ``line['choices']`` is passed to ``eval``;
            presumably it holds a list literal of option strings — confirm
            against the dataset.
        prefetch: If True, inspect the raw model prediction; otherwise
            inspect the judge-extracted ``res`` field.

    Returns:
        * Multiple-choice item with ``prefetch=True``: whatever ``can_infer``
          returns for the prediction, without comparing it to the answer.
        * Otherwise: on a match, the response (``prefetch=True``) or ``True``
          (``prefetch=False``); on a mismatch, ``False``.

    A ValueError raised while handling the choices is ignored; the comparison
    then runs with whatever was assigned before the error (initially None).
    """
    answer = line['answer']
    response = line['prediction'] if prefetch else line['res']
    candidate = None
    try:
        # NOTE(security): eval on dataset text; if `choices` is always a list
        # literal, ast.literal_eval would be a safer parser.
        options = eval(line['choices'])
        if len(options) > 0:
            candidate = can_infer(response, list_to_dict(options))
            if prefetch:
                return candidate
        else:
            candidate = str(response)
            answer = str(answer)
    except ValueError:
        pass

    if not is_equal(candidate, answer):
        return False
    return candidate if prefetch else True
|
|
|
|
|
|
def MATH_V_auxeval(model, line, retry=5):
    """Extract the final answer for one record, using the judge model if needed.

    First tries ``post_check(line, prefetch=True)`` on the raw prediction. If
    that yields nothing truthy, the extraction prompt is sent to
    ``model.generate`` up to ``retry`` times, raising the sampling temperature
    by 0.5 on each attempt.

    Args:
        model: Judge exposing ``generate(prompt, temperature=...)``; its output
            is assumed to be a string that contains ``FAIL_MSG`` when the API
            call failed.
        line: Record with the fields read by ``post_check`` and
            ``build_mathv_gpt4_prompt``.
        retry: Maximum number of judge calls (default 5).

    Returns:
        ``dict(log=..., res=...)``. ``log`` is exactly ``'Prefetch succeed'``
        when no judge call was needed (MATH_V_acc keys on this string);
        ``res`` is ``''`` when every attempt failed.
    """
    # Evaluate once and reuse the result; calling post_check a second time
    # only repeated the work and any warnings is_equal prints.
    prefetched = post_check(line, prefetch=True)
    if prefetched:
        return dict(log='Prefetch succeed', res=prefetched)

    # Build the prompt only once we know the judge is actually needed.
    prompt = build_mathv_gpt4_prompt(line)
    prediction = line['prediction']
    log = ''
    for i in range(retry):
        res = model.generate(prompt, temperature=i * 0.5)
        if FAIL_MSG in res:
            log += f'Try {i}: output is {prediction}, failed to parse.\n'
        else:
            log += 'Succeed'
            return dict(log=log, res=res)
    log += f'All {retry} retries failed.\n'
    return dict(log=log, res='')
|
|
|
|
|
|
def MATH_V_acc(result_file):
    """Summarise prefetch rate and accuracy per category from a result file.

    Args:
        result_file: Path passed to ``load``; the loaded table must have
            ``category`` and ``log`` columns plus the fields ``post_check``
            reads.

    Returns:
        DataFrame with one row per category plus an ``'Overall'`` row, sorted
        by ``Subject``, with columns ``Subject``, ``tot``, ``prefetch``,
        ``hit``, ``prefetch_rate`` and ``acc`` (the last two in percent).
    """
    data = load(result_file)
    tot = defaultdict(int)
    fetch = defaultdict(int)
    hit = defaultdict(int)

    for _, item in data.iterrows():
        # Every record counts toward the overall total and its own category.
        groups = ('Overall', item['category'])
        for key in groups:
            tot[key] += 1
        if item['log'] == 'Prefetch succeed':
            for key in groups:
                fetch[key] += 1
        if post_check(item, prefetch=False):
            for key in groups:
                hit[key] += 1

    table = defaultdict(list)
    for key, count in tot.items():
        table['Subject'].append(key)
        table['tot'].append(count)
        table['prefetch'].append(fetch[key])
        table['hit'].append(hit[key])
        table['prefetch_rate'].append(fetch[key] / count * 100)
        table['acc'].append(hit[key] / count * 100)
    return pd.DataFrame(table).sort_values('Subject', ignore_index=True)
|