Modify eval_mm for MiniCPM-o 2.6

This commit is contained in:
Poppy Xu
2025-01-21 15:34:54 +08:00
parent ec68cefc17
commit d8f382e157
82 changed files with 14279 additions and 843 deletions

View File

@@ -22,7 +22,7 @@ from eval_utils.vqa_evaluate import *
def get_model(args):
if args.model_name == '':
raise Exception('Model name cannot be empty str!')
from models.MiniCPM.minicpmv import MiniCPM_V, MiniCPM_V_2_6
from models.MiniCPM.minicpmv import MiniCPM_V, MiniCPM_V_2_6, MiniCPM_o_2_6
model_path = args.model_path
ckpt = args.ckpt
@@ -30,6 +30,8 @@ def get_model(args):
model = MiniCPM_V(model_path=model_path, ckpt=ckpt, device=args.device)
elif args.model_name == 'minicpmv26':
model = MiniCPM_V_2_6(model_path=model_path, ckpt=ckpt, device=args.device)
elif args.model_name == 'minicpmo26':
model = MiniCPM_o_2_6(model_path=model_path, ckpt=ckpt, device=args.device)
else:
raise Exception(f"Unexpected Moedel Name {args.model_name}!")
@@ -67,15 +69,16 @@ def main(args):
dataset = docVQADataset(args.docVQA_image_dir, args.docVQA_ann_path)
if max_sample_num is not None:
dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
acc = evaluate_VQA(model, dataset, args.model_name, 'docVQA', time, batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
acc = evaluate_VQA(model, dataset, args.model_name, 'docVQA', time, \
batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
result['docVQA'] = acc
if args.eval_docVQATest or args.eval_all:
target_dataset = "docVQATest"
dataset = docVQATESTDataset(args.docVQATest_image_dir, args.docVQATest_ann_path)
if max_sample_num is not None:
dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
acc = evaluate_VQA(model, dataset, args.model_name, target_dataset, time, batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
acc = evaluate_VQA(model, dataset, args.model_name, 'docVQATest', time, \
batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
result['docVQATest'] = acc
if torch.distributed.is_initialized():