mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-04 17:59:18 +08:00
117 lines
3.6 KiB
Python
117 lines
3.6 KiB
Python
import json
|
|
import os
|
|
import re
|
|
from torch.utils.data import Dataset
|
|
|
|
def prompt_processor(prompt):
    """Extract the bare question from a VQA-style prompt and lowercase it.

    Supported prompt layouts:
      1. Starts with 'OCR tokens: ' and embeds
         'Question: ... Short answer:' — the question is captured by regex.
      2. Three-line prompt containing 'Reference OCR token: ' — the question
         is the second line when the prompt starts with the token line,
         otherwise the first line.
      3. Two-line prompt — the question is the first line.

    Args:
        prompt: The full prompt string handed to the model.

    Returns:
        The extracted question, lowercased.

    Raises:
        ValueError: If the prompt matches none of the known layouts, or an
            'OCR tokens: ' prompt lacks the Question/Short-answer markers.
    """
    if prompt.startswith('OCR tokens: '):
        match = re.search(r"Question: (.*?) Short answer:", prompt, re.DOTALL)
        if match is None:
            # Original code called match.group(1) unconditionally, raising an
            # opaque AttributeError on malformed prompts; fail with a clear
            # message instead.
            raise ValueError(f"Malformed 'OCR tokens' prompt: {prompt!r}")
        question = match.group(1)
    elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
        if prompt.startswith('Reference OCR token:'):
            question = prompt.split('\n')[1]
        else:
            question = prompt.split('\n')[0]
    elif len(prompt.split('\n')) == 2:
        question = prompt.split('\n')[0]
    else:
        # Replaces `assert False`: asserts are stripped under `python -O` and
        # carry no diagnostic; raise an explicit, descriptive error instead.
        raise ValueError(f"Unrecognized prompt format: {prompt!r}")

    return question.lower()
|
|
|
|
class textVQADataset(Dataset):
    """TextVQA evaluation dataset.

    Loads the annotation JSON once at construction and resolves each image
    path as ``<image_dir>/<image_id>.jpg``. Images are NOT loaded here; only
    the path is returned, so consumers handle decoding themselves.
    """

    def __init__(
        self,
        image_dir="./downloads/TextVQA/train_images",
        ann_path="./downloads/TextVQA/TextVQA_0.5.1_val.json",
    ):
        # `with` ensures the annotation file handle is closed promptly
        # (the original `json.load(open(...))` leaked it until GC).
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample dict: question_id, image_path, question, gt_answers."""
        record = self.data[idx]
        question = record['question']
        answers = record['answers']
        img_id = record['image_id']
        qid = record['question_id']
        # TextVQA images are stored flat, keyed by image_id with .jpg extension.
        img_path = os.path.join(self.image_dir, f"{img_id}.jpg")

        item = {
            "question_id": qid,
            "image_path": img_path,
            "question": question,
            "gt_answers": answers
        }

        return item
|
|
|
|
class docVQADataset(Dataset):
    """DocVQA (SP-DocVQA) validation dataset.

    Loads the annotation JSON once at construction. Image paths in the
    annotations reference a ``documents/`` prefix which is rewritten to
    ``images/`` to match the extracted spdocvqa_images layout.
    """

    def __init__(
        self,
        image_dir= "./downloads/DocVQA/spdocvqa_images",
        ann_path= "./downloads/DocVQA/val_v1.0_withQT.json",
        ocr_token_path=None
    ):
        # `with` ensures the annotation file handle is closed promptly
        # (the original `json.load(open(...))` leaked it until GC).
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path
        # Always define the attribute: the original only assigned it when
        # ocr_token_path was given, leaving it undefined otherwise.
        self.ocr_token_data = {}
        if ocr_token_path:
            with open(ocr_token_path, "r") as f:
                self.ocr_token_data = {item['image_id']: item for item in json.load(f)["data"]}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample dict: question_id, image_path, question, gt_answers, question_type."""
        record = self.data[idx]
        question_id = record['questionId']
        relative_img_path = record['image']
        # Annotations reference 'documents/…' but the extracted archive
        # stores files under 'images/…'.
        corrected_relative_img_path = relative_img_path.replace("documents", "images")
        img_path = os.path.join(self.image_dir, corrected_relative_img_path)
        question = record['question']
        answers = record['answers']
        question_type = record['question_types']

        return {
            "question_id": question_id,
            "image_path": img_path,
            "question": question,
            "gt_answers": answers,
            'question_type': question_type,
        }
|
|
|
|
|
|
class docVQATESTDataset(Dataset):
    """DocVQA (SP-DocVQA) test-split dataset.

    Same layout as the validation dataset, but the test split ships no
    ground-truth answers or question types, so both are returned as empty
    strings to keep the sample schema uniform with docVQADataset.
    """

    def __init__(
        self,
        image_dir= "./downloads/DocVQA/spdocvqa_images",
        ann_path= "./downloads/DocVQA/test_v1.0.json",
        ocr_token_path=None
    ):
        # ocr_token_path is accepted (for signature parity with the val
        # dataset) but unused here.
        # `with` ensures the annotation file handle is closed promptly
        # (the original `json.load(open(...))` leaked it until GC).
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample dict; gt_answers/question_type are empty for the test split."""
        record = self.data[idx]
        question_id = record['questionId']
        relative_img_path = record['image']
        # Annotations reference 'documents/…' but the extracted archive
        # stores files under 'images/…'.
        corrected_relative_img_path = relative_img_path.replace("documents", "images")
        img_path = os.path.join(self.image_dir, corrected_relative_img_path)
        question = record['question']

        return {
            "question_id": question_id,
            "image_path": img_path,
            "question": question,
            "gt_answers": "",
            'question_type': "",
        }
|