mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-05 18:29:18 +08:00
Modify eval_mm for MiniCPM-o 2.6
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from ...smp import *
|
||||
from .multiple_choice import extract_answer_from_item
|
||||
from PIL import Image, ImageOps
|
||||
import torchvision
|
||||
import random
|
||||
@@ -32,9 +33,9 @@ def get_dimension_rating(data_path):
|
||||
def check_ans(pred, gt):
|
||||
flag = False
|
||||
|
||||
pred_list = pred.lower().split(' ')
|
||||
pred_list = pred.lower().strip().split(' ')
|
||||
pred_option, _ = pred_list[0], ' '.join(pred_list[1:])
|
||||
gt_list = gt.lower().split(' ')
|
||||
gt_list = gt.lower().strip().split(' ')
|
||||
gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
|
||||
if gt_content[-1] == '.':
|
||||
gt_content = gt_content[:-1]
|
||||
@@ -47,6 +48,64 @@ def check_ans(pred, gt):
|
||||
return flag
|
||||
|
||||
|
||||
def check_ans_with_model(pred, gt, model, item, dataset_name='MVBench'):
|
||||
flag = False
|
||||
|
||||
pred_list = pred.lower().strip().split(' ')
|
||||
pred_option, _ = pred_list[0], ' '.join(pred_list[1:])
|
||||
gt_list = gt.lower().strip().split(' ')
|
||||
gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
|
||||
if gt_content[-1] == '.':
|
||||
gt_content = gt_content[:-1]
|
||||
|
||||
if pred_option.replace('.', '') in gt_option:
|
||||
flag = True
|
||||
elif gt_option in pred_option:
|
||||
flag = True
|
||||
elif extract_answer_from_item(model, item, dataset_name)['opt'] == item['answer']:
|
||||
flag = True
|
||||
|
||||
return flag
|
||||
|
||||
|
||||
def check_ans_advanced(pred, gt):
|
||||
number_table = {
|
||||
0: 'zero',
|
||||
1: 'one',
|
||||
2: 'two',
|
||||
3: 'three',
|
||||
4: 'four',
|
||||
5: 'five',
|
||||
6: 'six',
|
||||
7: 'seven',
|
||||
8: 'eight',
|
||||
9: 'nine',
|
||||
}
|
||||
flag = False
|
||||
|
||||
pred_list = pred.lower().strip().split(' ')
|
||||
pred_option, _ = pred_list[0], ' '.join(pred_list[1:])
|
||||
gt_list = gt.lower().strip().split(' ')
|
||||
gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
|
||||
if gt_content[-1] == '.':
|
||||
gt_content = gt_content[:-1]
|
||||
|
||||
try:
|
||||
gt_content = number_table[int(gt_content.strip('. \n'))]
|
||||
print(gt_content)
|
||||
except:
|
||||
pass
|
||||
|
||||
if pred_option.replace('.', '') in gt_option:
|
||||
flag = True
|
||||
elif gt_option in pred_option:
|
||||
flag = True
|
||||
elif gt_content.lower().strip('. \n') in pred.lower().strip('. \n'):
|
||||
flag = True
|
||||
|
||||
return flag
|
||||
|
||||
|
||||
class GroupRandomCrop(object):
|
||||
def __init__(self, size):
|
||||
if isinstance(size, numbers.Number):
|
||||
|
||||
Reference in New Issue
Block a user