mirror of https://github.com/OpenBMB/MiniCPM-V.git

Add eval_mm dir
3
eval_mm/vqaeval/README.md
Normal file
@@ -0,0 +1,3 @@
# vqa-eval

Contains the vqa_eval kit for evaluating MiniCPM-V on TextVQA and DocVQA.
0
eval_mm/vqaeval/datasets/__init__.py
Normal file
116
eval_mm/vqaeval/datasets/vqa_dataset.py
Normal file
@@ -0,0 +1,116 @@
import json
import os
import re

from torch.utils.data import Dataset


def prompt_processor(prompt):
    if prompt.startswith('OCR tokens: '):
        pattern = r"Question: (.*?) Short answer:"
        match = re.search(pattern, prompt, re.DOTALL)
        question = match.group(1)
    elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
        if prompt.startswith('Reference OCR token:'):
            question = prompt.split('\n')[1]
        else:
            question = prompt.split('\n')[0]
    elif len(prompt.split('\n')) == 2:
        question = prompt.split('\n')[0]
    else:
        raise ValueError(f"Unexpected prompt format: {prompt!r}")

    return question.lower()


class textVQADataset(Dataset):
    def __init__(
        self,
        image_dir="./downloads/TextVQA/train_images",
        ann_path="./downloads/TextVQA/TextVQA_0.5.1_val.json",
    ):
        self.data = json.load(open(ann_path, "r"))["data"]
        self.image_dir = image_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question = self.data[idx]['question']
        answers = self.data[idx]['answers']
        img_id = self.data[idx]['image_id']
        qid = self.data[idx]['question_id']
        img_path = os.path.join(self.image_dir, f"{img_id}.jpg")

        item = {
            "question_id": qid,
            "image_path": img_path,
            "question": question,
            "gt_answers": answers
        }

        return item


class docVQADataset(Dataset):
    def __init__(
        self,
        image_dir="./downloads/DocVQA/spdocvqa_images",
        ann_path="./downloads/DocVQA/val_v1.0_withQT.json",
        ocr_token_path=None
    ):
        self.data = json.load(open(ann_path, "r"))["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path
        if ocr_token_path:
            self.ocr_token_data = {item['image_id']: item for item in json.load(open(ocr_token_path, "r"))["data"]}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question_id = self.data[idx]['questionId']
        relative_img_path = self.data[idx]['image']
        corrected_relative_img_path = relative_img_path.replace("documents", "images")
        img_path = os.path.join(self.image_dir, corrected_relative_img_path)
        question = self.data[idx]['question']
        answers = self.data[idx]['answers']

        question_type = self.data[idx]['question_types']

        return {
            "question_id": question_id,
            "image_path": img_path,
            "question": question,
            "gt_answers": answers,
            "question_type": question_type,
        }


class docVQATESTDataset(Dataset):
    def __init__(
        self,
        image_dir="./downloads/DocVQA/spdocvqa_images",
        ann_path="./downloads/DocVQA/test_v1.0.json",
        ocr_token_path=None
    ):
        self.data = json.load(open(ann_path, "r"))["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question_id = self.data[idx]['questionId']
        relative_img_path = self.data[idx]['image']
        corrected_relative_img_path = relative_img_path.replace("documents", "images")
        img_path = os.path.join(self.image_dir, corrected_relative_img_path)
        question = self.data[idx]['question']

        return {
            "question_id": question_id,
            "image_path": img_path,
            "question": question,
            "gt_answers": "",
            "question_type": "",
        }
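Note: a minimal usage sketch (not part of the commit), assuming the TextVQA download sits at the default paths from __init__:

    from datasets.vqa_dataset import textVQADataset

    dataset = textVQADataset()  # defaults: ./downloads/TextVQA/{train_images, TextVQA_0.5.1_val.json}
    print(len(dataset))
    sample = dataset[0]
    print(sample["question_id"], sample["image_path"], sample["question"])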
97
eval_mm/vqaeval/eval.py
Normal file
@@ -0,0 +1,97 @@
import sys
import datetime
import json
import os
import torch

script_dir = os.path.dirname(os.path.realpath(__file__))

sys.path.append(os.path.join(script_dir, '..'))

from datasets.vqa_dataset import docVQADataset, docVQATESTDataset, textVQADataset


print(torch.__version__)

import numpy as np

from eval_utils.getargs import parse_args
from eval_utils.vqa_evaluate import *


def get_model(args):
    if args.model_name == '':
        raise Exception('Model name cannot be an empty string!')
    from models.MiniCPM.minicpmv import MiniCPM_V
    model_path = args.model_path
    ckpt = args.ckpt
    model = MiniCPM_V(model_path=model_path, ckpt=ckpt, device=args.device)

    return model


def main(args):
    np.random.seed(0)
    max_sample_num = None

    torch.distributed.init_process_group(
        backend='nccl',
        world_size=int(os.getenv('WORLD_SIZE', '1')),
        rank=int(os.getenv('RANK', '0')),
    )
    torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
    print(f'Init Rank-{torch.distributed.get_rank()}')
    if torch.distributed.is_initialized():
        args.device = torch.device(f"cuda:{torch.cuda.current_device()}")

    model = get_model(args)

    result = {}
    time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")

    if args.eval_textVQA or args.eval_all:
        dataset = textVQADataset(args.textVQA_image_dir, args.textVQA_ann_path)
        if max_sample_num is not None:
            dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, 'textVQA', time,
                           batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
        result['textVQA'] = acc

    if args.eval_docVQA or args.eval_all:
        dataset = docVQADataset(args.docVQA_image_dir, args.docVQA_ann_path)
        if max_sample_num is not None:
            dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, 'docVQA', time, batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
        result['docVQA'] = acc

    if args.eval_docVQATest or args.eval_all:
        target_dataset = "docVQATest"
        dataset = docVQATESTDataset(args.docVQATest_image_dir, args.docVQATest_ann_path)
        if max_sample_num is not None:
            dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, target_dataset, time, batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
        result['docVQATest'] = acc

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
        return None

    result_path = os.path.join(os.path.join(args.answer_path, args.model_name), 'result.json')

    output_flag = False
    for v in result.values():
        if v > 0.0:
            output_flag = True
            break

    if output_flag:
        with open(result_path, "w") as f:
            f.write(json.dumps(result, indent=4))


if __name__ == "__main__":
    args = parse_args()

    main(args)
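Note: main() unconditionally calls torch.distributed.init_process_group with the nccl backend and env:// rendezvous, so even a single-GPU run needs the usual environment variables. A minimal sketch (values are the standard single-process defaults):

    import os

    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("LOCAL_RANK", "0")
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "12345")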
40
eval_mm/vqaeval/eval_utils/cal_metric.py
Normal file
@@ -0,0 +1,40 @@
import json
import glob
import re

def has_word(sentence, word):
    pattern = r"\b" + re.escape(word) + r"\b"
    match = re.search(pattern, sentence)
    if match:
        return True
    else:
        return False

def remove_special_chars(s):
    pattern = r"[^a-zA-Z0-9\s]"
    s = re.sub(pattern, "", s)
    return s

for model in glob.glob('./answer_save/*'):
    print(model, ':')
    result_list = sorted(glob.glob(f'{model}/*.json'))
    for task_result_path in result_list:
        taskname = task_result_path.split('/')[-1]
        taskname = taskname.split('.')[0]
        if taskname not in ['IIIT5K', 'svt', 'IC13_857', 'IC15_1811', 'svtp', 'ct80',
                            'cocotext', 'ctw', 'totaltext', 'HOST']:
            continue

        correct = 0
        num = 0
        with open(task_result_path, 'r') as f:
            records = json.load(f)[:100]  # renamed from `dict` to avoid shadowing the builtin
        for i in range(len(records)):
            gt_answers = records[i]['gt_answers']
            answer = records[i]['answer']
            gt_answers = remove_special_chars(gt_answers).lower()
            answer = remove_special_chars(answer).lower()
            if has_word(answer, gt_answers):
                correct += 1
            num += 1
        print(f'{taskname:10s}:{float(correct)/num*100:.2f}')
    print('=' * 32)
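Note: quick sanity checks (not part of the commit) on the matching helpers — the \b anchors keep has_word from rewarding substring hits, and remove_special_chars normalizes both sides before comparison:

    assert has_word("the answer is 42", "42")
    assert not has_word("misanswer", "answer")  # word boundary blocks substring matches
    assert remove_special_chars("it's 4:20pm!") == "its 420pm"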
62
eval_mm/vqaeval/eval_utils/getargs.py
Normal file
@@ -0,0 +1,62 @@
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="Demo")

    parser.add_argument('--local-rank', type=int, default=0, help='Local rank for distributed training')

    # textVQA
    parser.add_argument("--textVQA_image_dir", type=str, default="")
    parser.add_argument("--textVQA_ann_path", type=str, default="")

    # docVQA
    parser.add_argument("--docVQA_image_dir", type=str, default="")
    parser.add_argument("--docVQA_ann_path", type=str, default="")

    # docVQATest
    parser.add_argument("--docVQATest_image_dir", type=str, default="")
    parser.add_argument("--docVQATest_ann_path", type=str, default="")

    # result path
    parser.add_argument("--answer_path", type=str, default="./answers-new")

    # eval
    parser.add_argument(
        "--eval_textVQA",
        action="store_true",
        default=False,
        help="Whether to evaluate on textVQA."
    )
    parser.add_argument(
        "--eval_docVQA",
        action="store_true",
        default=False,
        help="Whether to evaluate on docVQA."
    )
    parser.add_argument(
        "--eval_docVQATest",
        action="store_true",
        default=False,
        help="Whether to evaluate on the docVQA test split."
    )

    parser.add_argument(
        "--eval_all",
        action="store_true",
        default=False,
        help="Whether to evaluate all datasets."
    )

    parser.add_argument("--model_name", type=str, default="")
    parser.add_argument("--model_path", type=str, default="")

    parser.add_argument("--generate_method", type=str, default="", help="Generation paradigm: 'interleave' or 'old'.")

    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument('--batchsize', type=int, default=1, help='Batch size for processing.')

    parser.add_argument("--ckpt", type=str, default="")

    args = parser.parse_args()
    return args
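Note: a sketch (not part of the commit) of exercising parse_args programmatically; the argv values are hypothetical:

    import sys
    from eval_utils.getargs import parse_args

    sys.argv = ["eval.py", "--model_name", "minicpm", "--eval_docVQA", "--batchsize", "1"]
    args = parse_args()
    print(args.model_name, args.eval_docVQA, args.batchsize)  # minicpm True 1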
446
eval_mm/vqaeval/eval_utils/vqa_evaluate.py
Normal file
@@ -0,0 +1,446 @@
import itertools
import json
import os
import re
from collections import namedtuple

import torch
from tqdm import tqdm


class InferenceSampler(torch.utils.data.sampler.Sampler):

    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size,
                                                      self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        shard_size = total_size // world_size
        left = total_size % world_size
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[:rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)

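# Note (illustrative sketch, not in the original commit): _get_local_indices
# hands each rank a contiguous shard, giving any remainder to the lowest ranks.
# For example, 10 samples across 3 ranks:
#   InferenceSampler._get_local_indices(10, 3, 0) -> range(0, 4)   # [0, 1, 2, 3]
#   InferenceSampler._get_local_indices(10, 3, 1) -> range(4, 7)   # [4, 5, 6]
#   InferenceSampler._get_local_indices(10, 3, 2) -> range(7, 10)  # [7, 8, 9]
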
def collate_fn_vqa(batches):
    '''Unpack a batch of dataset dicts into parallel lists.'''
    image_paths = [_['image_path'] for _ in batches]
    questions = [_['question'] for _ in batches]
    gt_answers = [_['gt_answers'] for _ in batches]
    ocr_tokens = [_['ocr_tokens'] if 'ocr_tokens' in _ else None for _ in batches]
    question_ids = [_['question_id'] if 'question_id' in _ else None for _ in batches]
    question_type = [_['question_type'] if 'question_type' in _ else None for _ in batches]

    return image_paths, questions, gt_answers, ocr_tokens, question_ids, question_type


def has_word(sentence, word):
    if word[0].isalnum():
        start_pattern = r"\b"
    else:
        start_pattern = r""

    if word[-1].isalnum():
        end_pattern = r"\b"
    else:
        end_pattern = r""

    pattern = start_pattern + re.escape(word) + end_pattern
    match = re.search(pattern, sentence)
    return bool(match)


def remove_special_chars(s):
    pattern = r"[^a-zA-Z0-9\s]"
    s = re.sub(pattern, "", s)
    return s


def levenshtein_distance(s1, s2):
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]

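# Note (illustrative sketch, not in the original commit): this is the standard
# dynamic-programming edit distance that drives the ANLS metric below, e.g.
#   levenshtein_distance("kitten", "sitting") -> 3
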
class VQAEval:
    def __init__(self):
        self.contractions = {
            "aint": "ain't",
            "arent": "aren't",
            "cant": "can't",
            "couldve": "could've",
            "couldnt": "couldn't",
            "couldn'tve": "couldn't've",
            "couldnt've": "couldn't've",
            "didnt": "didn't",
            "doesnt": "doesn't",
            "dont": "don't",
            "hadnt": "hadn't",
            "hadnt've": "hadn't've",
            "hadn'tve": "hadn't've",
            "hasnt": "hasn't",
            "havent": "haven't",
            "hed": "he'd",
            "hed've": "he'd've",
            "he'dve": "he'd've",
            "hes": "he's",
            "howd": "how'd",
            "howll": "how'll",
            "hows": "how's",
            "Id've": "I'd've",
            "I'dve": "I'd've",
            "Im": "I'm",
            "Ive": "I've",
            "isnt": "isn't",
            "itd": "it'd",
            "itd've": "it'd've",
            "it'dve": "it'd've",
            "itll": "it'll",
            "let's": "let's",
            "maam": "ma'am",
            "mightnt": "mightn't",
            "mightnt've": "mightn't've",
            "mightn'tve": "mightn't've",
            "mightve": "might've",
            "mustnt": "mustn't",
            "mustve": "must've",
            "neednt": "needn't",
            "notve": "not've",
            "oclock": "o'clock",
            "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at",
            "'ows'at": "'ow's'at",
            "'ow'sat": "'ow's'at",
            "shant": "shan't",
            "shed've": "she'd've",
            "she'dve": "she'd've",
            "she's": "she's",
            "shouldve": "should've",
            "shouldnt": "shouldn't",
            "shouldnt've": "shouldn't've",
            "shouldn'tve": "shouldn't've",
            "somebody'd": "somebodyd",
            "somebodyd've": "somebody'd've",
            "somebody'dve": "somebody'd've",
            "somebodyll": "somebody'll",
            "somebodys": "somebody's",
            "someoned": "someone'd",
            "someoned've": "someone'd've",
            "someone'dve": "someone'd've",
            "someonell": "someone'll",
            "someones": "someone's",
            "somethingd": "something'd",
            "somethingd've": "something'd've",
            "something'dve": "something'd've",
            "somethingll": "something'll",
            "thats": "that's",
            "thered": "there'd",
            "thered've": "there'd've",
            "there'dve": "there'd've",
            "therere": "there're",
            "theres": "there's",
            "theyd": "they'd",
            "theyd've": "they'd've",
            "they'dve": "they'd've",
            "theyll": "they'll",
            "theyre": "they're",
            "theyve": "they've",
            "twas": "'twas",
            "wasnt": "wasn't",
            "wed've": "we'd've",
            "we'dve": "we'd've",
            "weve": "we've",
            "werent": "weren't",
            "whatll": "what'll",
            "whatre": "what're",
            "whats": "what's",
            "whatve": "what've",
            "whens": "when's",
            "whered": "where'd",
            "wheres": "where's",
            "whereve": "where've",
            "whod": "who'd",
            "whod've": "who'd've",
            "who'dve": "who'd've",
            "wholl": "who'll",
            "whos": "who's",
            "whove": "who've",
            "whyll": "why'll",
            "whyre": "why're",
            "whys": "why's",
            "wont": "won't",
            "wouldve": "would've",
            "wouldnt": "wouldn't",
            "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've",
            "yall": "y'all",
            "yall'll": "y'all'll",
            "y'allll": "y'all'll",
            "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've",
            "y'all'dve": "y'all'd've",
            "youd": "you'd",
            "youd've": "you'd've",
            "you'dve": "you'd've",
            "youll": "you'll",
            "youre": "you're",
            "youve": "you've",
        }
        self.manualMap = {
            "none": "0",
            "zero": "0",
            "one": "1",
            "two": "2",
            "three": "3",
            "four": "4",
            "five": "5",
            "six": "6",
            "seven": "7",
            "eight": "8",
            "nine": "9",
            "ten": "10",
        }
        self.articles = ["a", "an", "the"]

        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        self.commaStrip = re.compile(r"(\d)(\,)(\d)")
        self.punct = [
            ";",
            r"/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]

    def clean_text(self, text):
        text = text.replace("\n", " ").replace("\t", " ").strip()
        text = self.processPunctuation(text)
        text = self.processDigitArticle(text)
        return text

    def evaluate_vqa_human(self, answer, gt_answers):
        '''TextVQA, VQAv2, OKVQA, vizwiz'''
        answer = answer.replace("\n", " ").replace("\t", " ").strip()
        answer = self.processPunctuation(answer)
        answer = self.processDigitArticle(answer)
        gt_answers = [self.processPunctuation(ans) for ans in gt_answers]
        gt_answers = [self.processDigitArticle(ans) for ans in gt_answers]

        gtAcc = []

        for idx, gtAnsDatum in enumerate(gt_answers):
            otherGTAns = gt_answers[:idx] + gt_answers[idx + 1:]

            matchingAns = [item for item in otherGTAns if answer == item]

            acc = min(1, float(len(matchingAns)) / 3)
            gtAcc.append(acc)

        avgGTAcc = float(sum(gtAcc)) / len(gtAcc) if gtAcc else 0

        return avgGTAcc

    def evaluate_anls(self, answer, gt_answers, threshold=0.5):
        '''DocVQA, InfographicsVQA, STVQA'''
        answer = ' '.join(answer.strip().lower().split())
        if not isinstance(gt_answers, list):
            gt_answers = [gt_answers]
        gt_answers = [' '.join(gt_answer.strip().lower().split()) for gt_answer in gt_answers]

        values = []
        for gt_answer in gt_answers:
            dist = levenshtein_distance(answer, gt_answer)
            length = max(len(answer), len(gt_answer))
            values.append(0.0 if length == 0 else float(dist) / float(length))

        score = 1 - min(values)

        score = 0 if score < threshold else score

        return score

    def processPunctuation(self, inText):
        outText = inText
        for p in self.punct:
            if (p + " " in inText or " " + p in inText) or (
                re.search(self.commaStrip, inText) is not None
            ):
                outText = outText.replace(p, "")
            else:
                outText = outText.replace(p, " ")
        outText = self.periodStrip.sub("", outText)  # flags belong in re.compile, not in sub's count slot
        return outText

    def processDigitArticle(self, inText):
        outText = []
        tempText = inText.lower().split()
        for word in tempText:
            word = self.manualMap.setdefault(word, word)
            if word not in self.articles:
                outText.append(word)
            else:
                pass
        for wordId, word in enumerate(outText):
            if word in self.contractions:
                outText[wordId] = self.contractions[word]
        outText = " ".join(outText)
        return outText

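# Note (illustrative sketch, not in the original commit): worked examples of the
# two metrics above, computed by hand from this code:
#   VQAEval().evaluate_anls("answr", ["answer"])
#     -> distance 1 over max length 6 -> 1 - 1/6 = 0.8333... (above the 0.5 threshold)
#   VQAEval().evaluate_anls("cat", ["skyscraper"])
#     -> normalized distance > 0.5, so the score is clipped to 0
#   VQAEval().evaluate_vqa_human("red", ["red"] * 2 + ["blue"] * 8)
#     -> leave-one-out min(matches/3, 1): (2 * 1/3 + 8 * 2/3) / 10 = 0.6
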
def evaluate_dataset(dataset_name, answer_file_path, model_name, method=None):
    with open(answer_file_path, 'r', encoding='utf-8') as f:
        predictions = json.load(f)

    evaluator = VQAEval()  # renamed from `eval` to avoid shadowing the builtin
    total_accuracy = 0
    num = 0
    Entry = namedtuple('Entry', ['text', 'bbox'])

    for item in predictions:
        gt_answers = item['gt_answers']
        answer = item['answer']
        if method is not None:
            pass
        if dataset_name in ["textVQA"]:
            if num == 0:
                print("evaluating vqa...")
            accuracy = evaluator.evaluate_vqa_human(answer, gt_answers)
        elif dataset_name in ['docVQA']:
            if num == 0:
                print("evaluating anls...")
            accuracy = evaluator.evaluate_anls(answer, gt_answers)
        else:
            # Only textVQA (human accuracy) and docVQA (ANLS) are scored by this
            # kit; the original fell through to an undefined evaluate_has().
            raise NotImplementedError(f"No metric implemented for dataset {dataset_name}")
        item['accuracy'] = accuracy

        total_accuracy += accuracy
        num += 1

    average_accuracy = total_accuracy / num
    print(f'{dataset_name}:{average_accuracy}')

    answer_model_method_path = answer_file_path.replace('.json', f'_{model_name}_{method}.json')
    with open(answer_model_method_path, "w", encoding='utf-8') as f:
        json.dump(predictions, f, indent=4, ensure_ascii=False)

    return average_accuracy

def evaluate_VQA(
    model,
    dataset,
    model_name,
    dataset_name,
    time,
    batch_size=1,
    generate_method="interleave",
    answer_path='./answers',
):
    print(f"answer path:{answer_path}")

    sampler = None
    if torch.distributed.is_initialized():
        sampler = InferenceSampler(len(dataset))

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=collate_fn_vqa
    )

    answer_dir = os.path.join(answer_path, model_name, time)
    os.makedirs(answer_dir, exist_ok=True)

    predictions = []

    for batch in tqdm(dataloader, desc="Running inference"):
        image_paths, questions, gt_answers, ocr_tokens_list, question_ids, question_type = batch

        with torch.no_grad():
            if model_name != "minicpm":
                if model_name != "codellama":
                    outputs = model.generate(images=image_paths, questions=questions, datasetname=dataset_name)
                else:
                    outputs = model.generate()
            elif model_name == "minicpm":
                if generate_method == "old":
                    outputs = model.generate(images=image_paths, questions=questions, datasetname=dataset_name)
                elif generate_method == "interleave":
                    outputs = model.generate_with_interleaved(images=image_paths, questions=questions, datasetname=dataset_name)
                else:
                    raise Exception(f"Wrong generate paradigm {generate_method}!")

        for i in range(len(outputs)):
            answer_dict = {
                'question_id': question_ids[i],
                'question': questions[i],
                'answer': outputs[i],
                'gt_answers': gt_answers[i],
                'image_path': image_paths[i],
                'model_name': model_name,
                'question_type': question_type[i]
            }
            predictions.append(answer_dict)

    if torch.distributed.is_initialized():
        torch.distributed.barrier()
    if torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
        merged_predictions = [None for _ in range(world_size)]
        torch.distributed.all_gather_object(merged_predictions, predictions)
        predictions = [_ for _ in itertools.chain.from_iterable(merged_predictions)]

    if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
        return None

    answer_file_path = os.path.join(answer_dir, f"{dataset_name}.json")
    print(f"answer_file_path:{answer_file_path}")

    with open(answer_file_path, "w", encoding='utf-8') as f:
        json.dump(predictions, f, indent=4, ensure_ascii=False)

    if dataset_name in ["docVQATest"]:
        return -1.0

    return evaluate_dataset(answer_file_path=answer_file_path, dataset_name=dataset_name, model_name=model_name)
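Note: reconstructing the answer-file layout (not part of the commit; the timestamp is hypothetical) that evaluate_VQA and evaluate_dataset produce:

    import os

    answer_path, model_name, time_str, dataset_name = "./answers", "minicpm", "20240501123000", "docVQA"
    answer_file = os.path.join(answer_path, model_name, time_str, f"{dataset_name}.json")
    scored_file = answer_file.replace(".json", f"_{model_name}_None.json")  # method defaults to None
    print(answer_file)  # ./answers/minicpm/20240501123000/docVQA.json
    print(scored_file)  # ./answers/minicpm/20240501123000/docVQA_minicpm_None.json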
96
eval_mm/vqaeval/models/MiniCPM/minicpmv.py
Normal file
@@ -0,0 +1,96 @@
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

Image.MAX_IMAGE_PIXELS = 1000000000

max_token = {
    'docVQA': 100,
    'textVQA': 100,
    "docVQATest": 100
}

class MiniCPM_V:

    def __init__(self, model_path, ckpt, device=None) -> None:
        self.model_path = model_path
        self.ckpt = ckpt
        self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).eval()
        if self.ckpt is not None:
            self.state_dict = torch.load(self.ckpt, map_location=torch.device('cpu'))
            self.model.load_state_dict(self.state_dict)

        self.model = self.model.to(dtype=torch.float16)
        self.model.to(device)

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        torch.cuda.empty_cache()

    def generate(self, images, questions, datasetname):
        image = Image.open(images[0]).convert('RGB')
        try:
            max_new_tokens = max_token[datasetname]
        except KeyError:
            max_new_tokens = 1024
        # docVQA, docVQATest and textVQA all share the same direct-answer prompt.
        prompt = "Answer the question directly with single word." + "\n" + questions[0]

        msgs = [{'role': 'user', 'content': prompt}]
        default_kwargs = dict(
            max_new_tokens=max_new_tokens,
            sampling=False,
            num_beams=3
        )
        res = self.model.chat(
            image=image,
            msgs=msgs,
            context=None,
            tokenizer=self.tokenizer,
            **default_kwargs
        )

        return [res]

    def generate_with_interleaved(self, images, questions, datasetname):
        try:
            max_new_tokens = max_token[datasetname]
        except KeyError:
            max_new_tokens = 1024

        prompt = "Answer the question directly with single word."

        default_kwargs = dict(
            max_new_tokens=max_new_tokens,
            sampling=False,
            num_beams=3
        )

        content = []
        message = [
            {'type': 'text', 'value': prompt},
            {'type': 'image', 'value': images[0]},
            {'type': 'text', 'value': questions[0]}
        ]
        for x in message:
            if x['type'] == 'text':
                content.append(x['value'])
            elif x['type'] == 'image':
                image = Image.open(x['value']).convert('RGB')
                content.append(image)
        msgs = [{'role': 'user', 'content': content}]

        res = self.model.chat(
            msgs=msgs,
            context=None,
            tokenizer=self.tokenizer,
            **default_kwargs
        )

        if isinstance(res, tuple) and len(res) > 0:
            res = res[0]
        print(f"Q: {content}, \nA: {res}")
        return [res]
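Note: a hedged usage sketch (not part of the commit); the HF model id and image path are placeholders:

    from models.MiniCPM.minicpmv import MiniCPM_V

    model = MiniCPM_V(model_path="openbmb/MiniCPM-V", ckpt=None, device="cuda:0")
    answers = model.generate_with_interleaved(
        images=["./downloads/DocVQA/spdocvqa_images/sample.png"],
        questions=["What is the invoice number?"],
        datasetname="docVQA",
    )
    print(answers[0])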
49
eval_mm/vqaeval/requirements.txt
Normal file
@@ -0,0 +1,49 @@
accelerate
aiohttp==3.8.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==22.2.0
bitsandbytes==0.37.0
cchardet==2.1.7
chardet==5.1.0
contourpy==1.0.7
cycler==0.11.0
filelock==3.9.0
fonttools==4.38.0
frozenlist==1.3.3
huggingface-hub==0.13.4
importlib-resources==5.12.0
kiwisolver==1.4.4
matplotlib==3.7.0
multidict==6.0.4
openai==0.27.0
packaging==23.0
psutil==5.9.4
pycocotools==2.0.6
pyparsing==3.0.9
python-dateutil==2.8.2
pyyaml==6.0
regex==2022.10.31
tokenizers==0.13.2
tqdm==4.64.1
transformers
timm==0.6.13
spacy==3.5.1
webdataset==0.2.48
scikit-learn==1.2.2
scipy==1.10.1
yarl==1.8.2
zipp==3.14.0
omegaconf==2.3.0
opencv-python==4.7.0.72
iopath==0.1.10
decord==0.6.0
tenacity==8.2.2
peft
pycocoevalcap
sentence-transformers
umap-learn
notebook
gradio==3.24.1
gradio-client==0.0.8
wandb
15
eval_mm/vqaeval/shell/run_inference.sh
Normal file
@@ -0,0 +1,15 @@
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
python -m torch.distributed.launch \
    --nproc_per_node=${NPROC_PER_NODE:-8} \
    --nnodes=${WORLD_SIZE:-1} \
    --node_rank=${RANK:-0} \
    --master_addr=${MASTER_ADDR:-127.0.0.1} \
    --master_port=${MASTER_PORT:-12345} \
    ./eval.py \
    --model_name minicpm \
    --model_path \
    --generate_method interleave \
    --eval_textVQA \
    --eval_docVQA \
    --answer_path ./answers \
    --batchsize 1
3
eval_mm/vqaeval/shell/run_transform.sh
Normal file
@@ -0,0 +1,3 @@
python ./transform_docvqatest_for_submission.py \
    --input_file_path \
    --output_file_path
16
eval_mm/vqaeval/transform_docvqatest_for_submission.py
Normal file
@@ -0,0 +1,16 @@
import argparse
import json

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file_path", type=str, default="", help="path to the original output json.")
    parser.add_argument("--output_file_path", type=str, default="", help="path to where you want to save the processed json.")
    args = parser.parse_args()

    with open(args.input_file_path, 'r') as f:
        data = json.load(f)

    transformed_data = [{"questionId": item["question_id"], "answer": item["answer"].replace("</s>", "")} for item in data]

    with open(args.output_file_path, 'w') as f:
        json.dump(transformed_data, f)
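Note: a sketch of the transformation (not part of the commit) with a hypothetical answer entry:

    data = [{"question_id": 1001, "answer": "42</s>"}]
    transformed = [{"questionId": d["question_id"], "answer": d["answer"].replace("</s>", "")} for d in data]
    print(transformed)  # [{'questionId': 1001, 'answer': '42'}]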