Modify eval_mm for MiniCPM-o 2.6

This commit is contained in:
Poppy Xu
2025-01-21 15:34:54 +08:00
parent ec68cefc17
commit d8f382e157
82 changed files with 14279 additions and 843 deletions

View File

@@ -1,4 +1,5 @@
from ...smp import *
from .multiple_choice import extract_answer_from_item
from PIL import Image, ImageOps
import torchvision
import random
@@ -32,9 +33,9 @@ def get_dimension_rating(data_path):
def check_ans(pred, gt):
flag = False
pred_list = pred.lower().split(' ')
pred_list = pred.lower().strip().split(' ')
pred_option, _ = pred_list[0], ' '.join(pred_list[1:])
gt_list = gt.lower().split(' ')
gt_list = gt.lower().strip().split(' ')
gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
if gt_content[-1] == '.':
gt_content = gt_content[:-1]
@@ -47,6 +48,64 @@ def check_ans(pred, gt):
return flag
def check_ans_with_model(pred, gt, model, item, dataset_name='MVBench'):
flag = False
pred_list = pred.lower().strip().split(' ')
pred_option, _ = pred_list[0], ' '.join(pred_list[1:])
gt_list = gt.lower().strip().split(' ')
gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
if gt_content[-1] == '.':
gt_content = gt_content[:-1]
if pred_option.replace('.', '') in gt_option:
flag = True
elif gt_option in pred_option:
flag = True
elif extract_answer_from_item(model, item, dataset_name)['opt'] == item['answer']:
flag = True
return flag
def check_ans_advanced(pred, gt):
number_table = {
0: 'zero',
1: 'one',
2: 'two',
3: 'three',
4: 'four',
5: 'five',
6: 'six',
7: 'seven',
8: 'eight',
9: 'nine',
}
flag = False
pred_list = pred.lower().strip().split(' ')
pred_option, _ = pred_list[0], ' '.join(pred_list[1:])
gt_list = gt.lower().strip().split(' ')
gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
if gt_content[-1] == '.':
gt_content = gt_content[:-1]
try:
gt_content = number_table[int(gt_content.strip('. \n'))]
print(gt_content)
except:
pass
if pred_option.replace('.', '') in gt_option:
flag = True
elif gt_option in pred_option:
flag = True
elif gt_content.lower().strip('. \n') in pred.lower().strip('. \n'):
flag = True
return flag
class GroupRandomCrop(object):
def __init__(self, size):
if isinstance(size, numbers.Number):