mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-04 17:59:18 +08:00
Modify eval_mm for MiniCPM-o 2.6
This commit is contained in:
@@ -370,8 +370,6 @@ def evaluate_VQA(
|
||||
generate_method="interleave",
|
||||
answer_path='./answers',
|
||||
):
|
||||
print(f"answer path:{answer_path}")
|
||||
|
||||
sampler = None
|
||||
if torch.distributed.is_initialized():
|
||||
sampler=InferenceSampler(len(dataset))
|
||||
@@ -383,8 +381,6 @@ def evaluate_VQA(
|
||||
collate_fn=collate_fn_vqa
|
||||
)
|
||||
|
||||
now_rank = torch.distributed.get_rank()
|
||||
|
||||
answer_dir = os.path.join(answer_path, model_name, time)
|
||||
os.makedirs(answer_dir, exist_ok=True)
|
||||
|
||||
@@ -395,21 +391,15 @@ def evaluate_VQA(
|
||||
predictions = []
|
||||
|
||||
for batch in tqdm(dataloader, desc="Running inference"):
|
||||
image_paths, questions, gt_answers, ocr_tokens_list, question_ids, question_type = batch
|
||||
image_paths, questions, gt_answers, ocr_tokens_list, question_ids, question_type = batch
|
||||
|
||||
with torch.no_grad():
|
||||
if model_name != "minicpm":
|
||||
if model_name != "codellama":
|
||||
outputs = model.generate(images=image_paths, questions=questions, datasetname=dataset_name)
|
||||
else:
|
||||
outputs = model.generate()
|
||||
elif model_name == "minicpm":
|
||||
if generate_method == "old":
|
||||
outputs = model.generate(images=image_paths, questions=questions, datasetname=dataset_name)
|
||||
elif generate_method == "interleave":
|
||||
outputs = model.generate_with_interleaved(images=image_paths, questions=questions, datasetname=dataset_name)
|
||||
else:
|
||||
raise Exception(f"Wrong generate paradigm {generate_method}!")
|
||||
if generate_method == "old":
|
||||
outputs = model.generate(images=image_paths, questions=questions, datasetname=dataset_name)
|
||||
elif generate_method == "interleave":
|
||||
outputs = model.generate_with_interleaved(images=image_paths, questions=questions, datasetname=dataset_name)
|
||||
else:
|
||||
raise Exception(f"Wrong generate paradigm {generate_method}!")
|
||||
|
||||
for i in range(len(outputs)):
|
||||
answer_dict = {
|
||||
|
||||
Reference in New Issue
Block a user