Modify eval_mm for MiniCPM-o 2.6

Poppy Xu
2025-01-21 15:34:54 +08:00
parent ec68cefc17
commit d8f382e157
82 changed files with 14279 additions and 843 deletions

View File

@@ -0,0 +1,59 @@
# CC-OCR: A Comprehensive and Challenging OCR Benchmark for Evaluating Large Multimodal Models in Literacy
## Introduction
Please refer to our [GitHub](https://github.com/AlibabaResearch/AdvancedLiterateMachinery/tree/main/Benchmarks/CC-OCR) for more information.
## Running Scripts
Once the environment is ready, execute the following script from the root directory of VLMEvalKit
to perform inference and evaluation tasks in batch.
```shell
MODEL_NAME="QwenVLMax"
OUTPUT_DIR="/your/path/to/output_dir"
SUB_OUTPUT_DIR=${OUTPUT_DIR}/multi_scene_ocr
python run.py --data CCOCR_MultiSceneOcr_Cord CCOCR_MultiSceneOcr_Funsd CCOCR_MultiSceneOcr_Iam CCOCR_MultiSceneOcr_ZhDoc CCOCR_MultiSceneOcr_ZhHandwriting CCOCR_MultiSceneOcr_Hieragent CCOCR_MultiSceneOcr_Ic15 CCOCR_MultiSceneOcr_Inversetext CCOCR_MultiSceneOcr_Totaltext CCOCR_MultiSceneOcr_ZhScene CCOCR_MultiSceneOcr_UgcLaion CCOCR_MultiSceneOcr_ZhDense CCOCR_MultiSceneOcr_ZhVertical --model ${MODEL_NAME} --work-dir ${SUB_OUTPUT_DIR} --verbose
python vlmeval/dataset/utils/ccocr_evaluator/common.py ${SUB_OUTPUT_DIR}
SUB_OUTPUT_DIR=${OUTPUT_DIR}/multi_lan_ocr
python run.py --data CCOCR_MultiLanOcr_Arabic CCOCR_MultiLanOcr_French CCOCR_MultiLanOcr_German CCOCR_MultiLanOcr_Italian CCOCR_MultiLanOcr_Japanese CCOCR_MultiLanOcr_Korean CCOCR_MultiLanOcr_Portuguese CCOCR_MultiLanOcr_Russian CCOCR_MultiLanOcr_Spanish CCOCR_MultiLanOcr_Vietnamese --model ${MODEL_NAME} --work-dir ${SUB_OUTPUT_DIR} --verbose
python vlmeval/dataset/utils/ccocr_evaluator/common.py ${SUB_OUTPUT_DIR}
SUB_OUTPUT_DIR=${OUTPUT_DIR}/doc_parsing
python run.py --data CCOCR_DocParsing_DocPhotoChn CCOCR_DocParsing_DocPhotoEng CCOCR_DocParsing_DocScanChn CCOCR_DocParsing_DocScanEng CCOCR_DocParsing_TablePhotoChn CCOCR_DocParsing_TablePhotoEng CCOCR_DocParsing_TableScanChn CCOCR_DocParsing_TableScanEng CCOCR_DocParsing_MolecularHandwriting CCOCR_DocParsing_FormulaHandwriting --model ${MODEL_NAME} --work-dir ${SUB_OUTPUT_DIR} --verbose
python vlmeval/dataset/utils/ccocr_evaluator/common.py ${SUB_OUTPUT_DIR}
SUB_OUTPUT_DIR=${OUTPUT_DIR}/kie
python run.py --data CCOCR_Kie_Sroie2019Word CCOCR_Kie_Cord CCOCR_Kie_EphoieScut CCOCR_Kie_Poie CCOCR_Kie_ColdSibr CCOCR_Kie_ColdCell --model ${MODEL_NAME} --work-dir ${SUB_OUTPUT_DIR} --verbose
python vlmeval/dataset/utils/ccocr_evaluator/common.py ${SUB_OUTPUT_DIR}
```
## Example Output
The evaluation results will be saved in `${SUB_OUTPUT_DIR}/summary.md`. For example, for the KIE subset,
the output is as follows:
| exp_name(f1_score) | COLD_CELL | COLD_SIBR | CORD | EPHOIE_SCUT | POIE | sroie2019_word | summary |
|:-------------------|------------:|------------:|-------:|--------------:|-------:|-----------------:|----------:|
| QwenVLMax | 81.01 | 72.46 | 69.33 | 71.2 | 60.85 | 76.37 | 71.87 |
## Citation
If you find our work helpful, please consider citing it:
```bibtex
@misc{yang2024ccocr,
title={CC-OCR: A Comprehensive and Challenging OCR Benchmark for Evaluating Large Multimodal Models in Literacy},
author={Zhibo Yang and Jun Tang and Zhaohai Li and Pengfei Wang and Jianqiang Wan and Humen Zhong and Xuejing Liu and Mingkun Yang and Peng Wang and Shuai Bai and LianWen Jin and Junyang Lin},
year={2024},
eprint={2412.02210},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2412.02210},
}
```
## Contact Us
If you have any questions, feel free to send an email to: wpf272043@alibaba-inc.com or xixing.tj@alibaba-inc.com

View File

@@ -0,0 +1,12 @@
from .kie_evaluator import KieEvaluator
from .doc_parsing_evaluator import ParsingEvaluator
from .ocr_evaluator import OcrEvaluator
from .common import summary
evaluator_map_info = {
"kie": KieEvaluator("kie"),
"doc_parsing": ParsingEvaluator("doc_parsing"),
"multi_lan_ocr": OcrEvaluator("multi_lan_ocr"),
"multi_scene_ocr": OcrEvaluator("multi_scene_ocr")
}
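# Illustrative usage (a sketch; the exact call pattern inside VLMEvalKit may differ):
#   evaluator = evaluator_map_info["kie"]
#   # predictions: a dict {image_name: response_text} or a folder of cached JSON responses
#   # gt_info: a dict {image_name: ground_truth}
#   meta_info, eval_info = evaluator(predictions, gt_info, dataset="CCOCR_Kie_Poie", op="poie")
#   print(eval_info["summary"])  # e.g. {"f1_score": ..., "acc": ..., "response_success_ratio": ...}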

View File

@@ -0,0 +1,222 @@
import os
import json
import time
import sys
from abc import abstractmethod
from tabulate import tabulate
def pick_response_text(json_path):
"""
"""
try:
with open(json_path, "r") as f:
json_data = json.load(f)
except Exception as e:
print("--> file error: msg: {}, path: {}".format(e, json_path))
return None
for required_key in ["model_name", "response"]:
if required_key not in json_data:
print("--> required key not exists, name: {}, path: {}".format(required_key, json_path))
return None
model_name = json_data["model_name"]
model_response = json_data["response"]
response_text = None
if model_name.startswith("gpt") or model_name.startswith("o1"):
response_text = model_response.get("data", {}).get("response", {}).get("choices", [{}])[0].get("message", {}).get("content", None) # noqa: E501
elif model_name.startswith("local_"):
response_text = model_response
else:
if model_name.startswith("claude"):
content_list = model_response.get("content", None)
elif model_name.startswith("gemini"):
content_list = model_response.get("candidates", [{}])[0].get("content", {}).get("parts", None)
elif model_name.startswith("qwen"):
content_list = model_response.get("output", {}).get("choices", [{}])[0].get("message", {}).get("content", None) # noqa: E501
else:
raise NotImplementedError("The pick_response_text NOT implemented for model: {}".format(model_name))
if isinstance(content_list, list) and len(content_list) > 0:
response_text = content_list[0].get("text", None)
if response_text is None:
print("--> [error][{}] text pick error, path: {}".format(model_name, json_path))
return response_text
def load_response_from_dir(res_dir):
"""
"""
response_info = {}
for file_name in os.listdir(res_dir):
file_path = os.path.abspath(os.path.join(res_dir, file_name))
if not file_name.endswith(".json"):
print("--> skip: result file should be a json: but got: {}".format(file_path))
continue
response_text = pick_response_text(file_path)
if response_text is None:
continue
file_name_wo_ext, ext = os.path.splitext(file_name)
response_info[file_name_wo_ext] = response_text
return response_info
class BaseMetric(object):
""" BaseMetric """
""" OCRMetric """
def __init__(self, group_name, **kwargs):
self.group_name = group_name
self.kwargs = kwargs
def response_post_func(self, response_text, **kwargs):
return response_text
    @abstractmethod
    def evaluate(self, response_info, gt_info, normalize_func=None, **kwargs):
        """
        Given the predictions and the ground truth, return the evaluation results as a dictionary.
        The result must contain a 'summary' key, for example:
        {
            "summary": {
                "f1-score": 99.99,
                "metric_name": "metric_value"  # only metrics intended for the summary table belong in this dict
            },
            "your other info": "xxx"
        }
        """
        pass
def __call__(self, pdt_res_dir, gt_info, with_response_ratio=True, **kwargs):
if isinstance(pdt_res_dir, dict):
raw_response_info = pdt_res_dir
elif os.path.exists(pdt_res_dir) and os.path.isdir(pdt_res_dir):
raw_response_info = load_response_from_dir(pdt_res_dir)
else:
            raise ValueError("invalid input: a response dict or an existing result folder is required, but got {}".format(pdt_res_dir))
post_error_list, response_info = [], {}
response_error_list = list(gt_info.keys() - raw_response_info.keys())
for file_name, single_pdt_str in raw_response_info.items():
single_pdt_str = self.response_post_func(single_pdt_str, **kwargs)
if single_pdt_str is None:
post_error_list.append(file_name)
continue
response_info[file_name] = single_pdt_str
meta_info = {
"gt_total_num": len(gt_info), "pdt_total_num": len(response_info),
"post_error_list": post_error_list, "response_error_list": response_error_list,
}
eval_info = self.evaluate(response_info, gt_info, **kwargs)
# add response_success_ratio
if "summary" in eval_info and with_response_ratio:
success_ratio = (len(response_info) + len(post_error_list)) / (len(gt_info) + 1e-9)
eval_info["summary"].update({"response_success_ratio": success_ratio})
return meta_info, eval_info
def summary(index_path, exp_dir_base, is_weighted_sum=False):
"""
"""
with open(index_path, "r") as f:
data_list = json.load(f)
all_data_info = {}
for data_info_item in data_list:
data_name = data_info_item["dataset"]
if not data_info_item.get("release", True):
continue
all_data_info[data_name] = data_info_item
dataset_list = list(all_data_info.keys())
summary_path = summary_multi_exp(exp_dir_base, dataset_list, is_weighted_sum=is_weighted_sum)
return summary_path
def summary_multi_exp(exp_dir_base, dataset_list=None, is_weighted_sum=False):
"""
"""
if dataset_list is None:
all_dataset_name = []
for exp_name in os.listdir(exp_dir_base):
dir_status_path = os.path.join(exp_dir_base, exp_name, "status.json")
if not os.path.exists(dir_status_path):
continue
with open(dir_status_path, "r") as f:
data_status_info = json.load(f)
all_dataset_name.extend(data_status_info.keys())
dataset_list = sorted(set(all_dataset_name))
# summary main code
    all_evaluate_info = {}
for exp_name in os.listdir(exp_dir_base):
dir_status_path = os.path.join(exp_dir_base, exp_name, "status.json")
if not os.path.exists(dir_status_path):
print("--> skip: status.json not exist: {}".format(dir_status_path))
continue
with open(dir_status_path, "r") as f:
all_status_info = json.load(f)
for data_name in dataset_list:
total_num = all_status_info.get(data_name, {}).get("config", {}).get("num", "-1")
summary_info = all_status_info.get(data_name, {}).get("evaluation", {}).get("summary", {})
for metric_name, metric_value in summary_info.items():
if metric_name not in all_evaluate_info:
all_evaluate_info[metric_name] = {}
if exp_name not in all_evaluate_info[metric_name]:
all_evaluate_info[metric_name][exp_name] = {}
all_evaluate_info[metric_name][exp_name][data_name] = (metric_value, total_num)
all_table_md = []
for metric_name, metric_info in all_evaluate_info.items():
formatted_time = time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time()))
summary_line_list = []
summary_key_name = "summary(weighted)" if is_weighted_sum else "summary"
summary_head = [f"exp_name({metric_name}_{formatted_time})"] + dataset_list + [summary_key_name]
for exp_name, data_eval_info in metric_info.items():
summary_line = [exp_name, ]
all_metric_value = 0
is_summary_valid, all_total_num, all_weighted_metric = True, 0, 0
for data_name in dataset_list:
metric_value, total_num = data_eval_info.get(data_name, ("-1", "-1"))
summary_line.append("{:.2f}".format(float(metric_value) * 100))
                if str(metric_value) == "-1" or str(total_num) == "-1":
is_summary_valid = False
continue
all_total_num += float(total_num)
all_weighted_metric += float(total_num) * float(metric_value)
all_metric_value += float(metric_value)
summary_value_valid = ((all_weighted_metric / (all_total_num + 1e-9)) * 100) if is_weighted_sum \
else (all_metric_value / (len(dataset_list) + 1e-9) * 100)
summary_value = "-" if not is_summary_valid else "{:.2f}".format(summary_value_valid)
summary_line.append(summary_value)
summary_line_list.append(summary_line)
md_table_info = tabulate(summary_line_list, headers=summary_head, tablefmt='pipe')
all_table_md.append(md_table_info)
print("\n\n".join(all_table_md))
summary_path = os.path.abspath(os.path.join(exp_dir_base, "summary.md"))
with open(summary_path, "w") as f:
f.write("\n\n".join(all_table_md))
return summary_path
if __name__ == '__main__':
if len(sys.argv) != 2:
print("Usage: python {} exp_base_dir".format(__file__))
exit(-1)
else:
print('--> info: {}'.format(sys.argv))
exp_base_dir = sys.argv[1]
summary_path = summary_multi_exp(exp_base_dir, dataset_list=None, is_weighted_sum=False)
print("--> info: summary saved at : {}".format(summary_path))
print("happy coding.")

View File

@@ -0,0 +1,256 @@
import nltk
import re
from tqdm import tqdm
from collections import deque
from apted.helpers import Tree
from apted import APTED, Config
# local import
from .common import BaseMetric
# LaTeX commands to remove from the predictions before scoring
patterns = [
r'\\documentclass\{.*?\}',
r'\\usepackage\[.*?\]\{.*?\}',
r'\\usepackage\{.*?\}',
r'\\geometry\{.*?\}',
r'\\begin\{document\}',
r'\\end\{document\}',
r'\\noindent'
]
class TableTree(Tree):
"""
# Copyright 2020 IBM
# Author: peter.zhong@au1.ibm.com
# License: Apache 2.0 License.
"""
def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
self.tag = tag
self.colspan = colspan
self.rowspan = rowspan
self.content = content
self.children = list(children)
def bracket(self):
"""Show tree using brackets notation"""
if self.tag == "td":
result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
self.tag,
self.colspan,
self.rowspan,
self.content,
)
else:
result = '"tag": %s' % self.tag
for child in self.children:
result += child.bracket()
return "{{{}}}".format(result)
class CustomConfig(Config):
"""
# Copyright 2020 IBM
# Author: peter.zhong@au1.ibm.com
# License: Apache 2.0 License.
"""
def rename(self, node1, node2):
"""Compares attributes of trees"""
# print(node1.tag)
if (
(node1.tag != node2.tag)
or (node1.colspan != node2.colspan)
or (node1.rowspan != node2.rowspan)
):
return 1.0
if node1.tag == "td":
if node1.content or node2.content:
return nltk.edit_distance(node1.content, node2.content) / max(len(node1.content), len(node2.content))
return 0.0
class TEDS(object):
"""Tree Edit Distance basead Similarity
# Copyright 2020 IBM
# Author: peter.zhong@au1.ibm.com
# License: Apache 2.0 License.
"""
def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
assert isinstance(n_jobs, int) and (
n_jobs >= 1
), "n_jobs must be an integer greather than 1"
self.structure_only = structure_only
self.n_jobs = n_jobs
self.ignore_nodes = ignore_nodes
self.__tokens__ = []
def tokenize(self, node):
"""Tokenizes table cells"""
self.__tokens__.append("<%s>" % node.tag)
if node.text is not None:
self.__tokens__ += list(node.text)
for n in node.getchildren():
self.tokenize(n)
if node.tag != "unk":
self.__tokens__.append("</%s>" % node.tag)
if node.tag != "td" and node.tail is not None:
self.__tokens__ += list(node.tail)
def load_html_tree(self, node, parent=None):
"""Converts HTML tree to the format required by apted"""
global __tokens__
if node.tag == "td":
if self.structure_only:
cell = []
else:
self.__tokens__ = []
self.tokenize(node)
cell = self.__tokens__[1:-1].copy()
new_node = TableTree(
node.tag,
int(node.attrib.get("colspan", "1")),
int(node.attrib.get("rowspan", "1")),
cell,
*deque(),
)
else:
new_node = TableTree(node.tag, None, None, None, *deque())
if parent is not None:
parent.children.append(new_node)
if node.tag != "td":
for n in node.getchildren():
self.load_html_tree(n, new_node)
if parent is None:
return new_node
def evaluate(self, pred, true):
"""Computes TEDS score between the prediction and the ground truth of a
given sample
"""
# try_import("lxml")
from lxml import etree, html
if (not pred) or (not true):
return 0.0
parser = html.HTMLParser(remove_comments=True, encoding="utf-8")
pred = html.fromstring(pred, parser=parser)
true = html.fromstring(true, parser=parser)
if pred.xpath("body/table") and true.xpath("body/table"):
pred = pred.xpath("body/table")[0]
true = true.xpath("body/table")[0]
if self.ignore_nodes:
etree.strip_tags(pred, *self.ignore_nodes)
etree.strip_tags(true, *self.ignore_nodes)
n_nodes_pred = len(pred.xpath(".//*"))
n_nodes_true = len(true.xpath(".//*"))
n_nodes = max(n_nodes_pred, n_nodes_true)
tree_pred = self.load_html_tree(pred)
tree_true = self.load_html_tree(true)
distance = APTED(
tree_pred, tree_true, CustomConfig()
).compute_edit_distance()
return 1.0 - (float(distance) / n_nodes)
else:
return 0.0
class ParsingEvaluator(BaseMetric):
def response_post_func(self, response_text, **kwargs):
return response_text
def evaluate(self, response_info, gt_info, **kwargs):
op = kwargs['op']
if op == 'doc':
score = self.eval_doc(response_info, gt_info)
elif op == 'table':
score = self.eval_table(response_info, gt_info)
elif op in ['molecular', "formula"]:
score = self.eval_formula(response_info, gt_info, op_name=op)
else:
raise ValueError(f'doc parsing unsupported op: {op}')
# summary info
eval_info = {"summary": {"score": score}}
return eval_info
def eval_doc(self, response_info, gt_info):
results = []
for img_name, gt in tqdm(gt_info.items()):
if img_name not in response_info:
results.append(0)
continue
pred = response_info[img_name]
for pattern in patterns:
pred = re.sub(pattern, '', pred)
            try:
                pred = pred.split('```')[1]
            except IndexError:
                pass
pred = pred.replace('```latex', '')
pred = pred.replace('```', '')
pred = pred.replace(' ', '').replace('\n', '')
gt = gt.replace(' ', '').replace('\n', '')
edit_dist = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
results.append(1 - edit_dist)
score = sum(results) / len(results)
return score
def eval_table(self, response_info, gt_info):
teds = TEDS(structure_only=False, n_jobs=1)
results = []
for img_name, gt in tqdm(gt_info.items()):
if img_name not in response_info:
results.append(0)
continue
pred = response_info[img_name]
for pattern in patterns:
pred = re.sub(pattern, '', pred)
            try:
                pred = pred.split('```html')[1]
            except IndexError:
                pass
pred = pred.replace('```', '')
            pred = pred.replace(' ', '').replace('\n', '').replace('，', ',')  # map fullwidth commas to ASCII
gt = gt.replace(' ', '').replace('\n', '')
pred_html = '<html><body>{}</body></html>'.format(pred)
gt_html = '<html><body>{}</body></html>'.format(gt)
results.append(teds.evaluate(pred_html, gt_html))
score = sum(results) / len(results)
return score
def eval_formula(self, response_info, gt_info, op_name='formula'):
results = []
for img_name, gt in tqdm(gt_info.items()):
if img_name not in response_info:
results.append(0)
continue
pred = response_info[img_name]
if op_name == 'formula':
pred = pred.replace("\n", " ").replace("```latex", "").replace("```", "").replace("\t", " ").replace(" ", "") # noqa: E501
gt = gt.replace(" ", "")
elif op_name == 'molecular':
pred = pred.replace("\n", "").replace(" ", "").replace("<smiles>", "").replace("</smiles>", "")
gt = gt.replace(" ", "")
edit_dist = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
results.append(1 - edit_dist)
score = sum(results) / len(results)
return score
if __name__ == '__main__':
pass
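    # Illustrative sanity check (a sketch; real data loading is handled by VLMEvalKit).
    # It scores a toy LaTeX formula prediction with the same normalized edit-distance rule
    # used in ParsingEvaluator.eval_formula; an exact match yields a score of 1.0.
    toy_gt = {"img_001": "E = m c^{2}"}
    toy_pred = {"img_001": "```latex\nE = mc^{2}\n```"}
    demo_score = ParsingEvaluator("doc_parsing").eval_formula(toy_pred, toy_gt, op_name="formula")
    print("demo formula score: {:.4f}".format(demo_score))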

View File

@@ -0,0 +1,385 @@
"""
Donut
Copyright (c) 2022-present NAVER Corp.
MIT License
"""
import json
import os
import sys
import re
import time
from typing import Any, Dict, List, Tuple, Union
import zss
from zss import Node
from collections import Counter
from nltk import edit_distance
# local import
from .common import BaseMetric
def flatten(data: dict):
"""
    Flatten a nested dictionary into a list of (key-path, value) pairs.
Example:
input(dict)
{
"menu": [
{"name" : ["cake"], "count" : ["2"]},
{"name" : ["juice"], "count" : ["1"]},
]
}
output(list)
[
("menu.name", "cake"),
("menu.count", "2"),
("menu.name", "juice"),
("menu.count", "1"),
]
"""
flatten_data = list()
def _flatten(value, key=""):
if type(value) is dict:
for child_key, child_value in value.items():
_flatten(child_value, f"{key}.{child_key}" if key else child_key)
elif type(value) is list:
for value_item in value:
_flatten(value_item, key)
else:
flatten_data.append((key, value))
_flatten(data)
return flatten_data
def update_cost(node1: Node, node2: Node):
"""
Update cost for tree edit distance.
    If both are leaf nodes, the cost is the string edit distance between the two labels (the special token '<leaf>' is ignored).
    If only one of them is a leaf node, the cost is the length of the string in that leaf node + 1.
    If neither is a leaf node, the cost is 0 if label1 equals label2, otherwise 1.
"""
label1 = node1.label
label2 = node2.label
label1_leaf = "<leaf>" in label1
label2_leaf = "<leaf>" in label2
if label1_leaf and label2_leaf:
return edit_distance(label1.replace("<leaf>", ""), label2.replace("<leaf>", ""))
elif not label1_leaf and label2_leaf:
return 1 + len(label2.replace("<leaf>", ""))
elif label1_leaf and not label2_leaf:
return 1 + len(label1.replace("<leaf>", ""))
else:
return int(label1 != label2)
def insert_and_remove_cost(node: Node):
"""
Insert and remove cost for tree edit distance.
    If the node is a leaf, the cost is the length of its label (without the '<leaf>' token); otherwise, the cost is 1.
"""
label = node.label
if "<leaf>" in label:
return len(label.replace("<leaf>", ""))
else:
return 1
def normalize_dict(data: Union[Dict, List, Any]):
"""
    Sort dictionary keys, normalize values recursively, and normalize list elements one by one.
"""
# if not data:
# return {}
if isinstance(data, dict):
new_data = dict()
for key in sorted(data.keys(), key=lambda k: (len(k), k)):
value = normalize_dict(data[key])
if value:
if not isinstance(value, list):
value = [value]
new_data[key] = value
elif isinstance(data, list):
if all(isinstance(item, dict) for item in data):
new_data = []
for item in data:
item = normalize_dict(item)
if item:
new_data.append(item)
else:
new_data = [str(item).strip() for item in data if type(item) in {str, int, float} and str(item).strip()]
else:
new_data = [str(data).strip()]
return new_data
def cal_f1_all(preds, answers):
"""
Calculate global F1 accuracy score (field-level, micro-averaged) by counting all true positives,
false negatives and false positives
"""
metric_info, error_info = {}, {}
total_tp, total_fn_or_fp = 0, 0
for file_name, answer in answers.items():
sample_error_info = {"fp": [], "fn": [], "tp": []}
pred = preds.get(file_name, {})
pred, answer = flatten(normalize_dict(pred)), flatten(normalize_dict(answer))
for field in pred:
field_name = field[0]
if field_name not in metric_info:
metric_info[field_name] = {"total_tp": 0, "total_fn_or_fp": 0}
if field in answer:
total_tp += 1
metric_info[field_name]["total_tp"] += 1
sample_error_info["tp"].append(field)
answer.remove(field)
else:
total_fn_or_fp += 1
metric_info[field_name]["total_fn_or_fp"] += 1
sample_error_info["fp"].append(field)
total_fn_or_fp += len(answer)
for field in answer:
field_name = field[0]
if field_name not in metric_info:
metric_info[field_name] = {"total_tp": 0, "total_fn_or_fp": 0}
metric_info[field_name]["total_fn_or_fp"] += 1
sample_error_info["fn"].append(field)
sample_error_num = sum([len(v) for k, v in sample_error_info.items() if k != "tp"])
if sample_error_num > 0:
sample_error_info["error_num"] = sample_error_num
error_class_list = ["counter_" + x[0] for x in (sample_error_info["fn"] + sample_error_info["fp"])]
counter = Counter(error_class_list)
sample_error_info["error_info"] = dict(counter)
error_info[file_name] = sample_error_info
# summary
for field_name, field_info in metric_info.items():
field_tp, field_fn_or_fp = field_info["total_tp"], field_info["total_fn_or_fp"]
metric_info[field_name]["acc"] = field_tp / (field_tp + field_fn_or_fp / 2 + 1e-6)
print("donut_evaluator: total_tp: {}, total_fn_or_fp: {}, ptd_num: {}, gt_num: {}".format(total_tp, total_fn_or_fp,
len(preds), len(answers)))
error_info = {k: v for k, v in
sorted(error_info.items(), key=lambda item: item[1].get("error_num", 0), reverse=True)}
metric_info = {k: v for k, v in
sorted(metric_info.items(), key=lambda item: item[1].get("total_fn_or_fp", 0), reverse=True)}
return total_tp / (total_tp + total_fn_or_fp / 2 + 1e-6), metric_info, error_info
def construct_tree_from_dict(data: Union[Dict, List], node_name: str = None):
"""
Convert Dictionary into Tree
Example:
input(dict)
{
"menu": [
{"name" : ["cake"], "count" : ["2"]},
{"name" : ["juice"], "count" : ["1"]},
]
}
output(tree)
<root>
|
menu
/ \
<subtree> <subtree>
/ | | \
name count name count
/ | | \
<leaf>cake <leaf>2 <leaf>juice <leaf>1
"""
if node_name is None:
node_name = "<root>"
node = Node(node_name)
if isinstance(data, dict):
for key, value in data.items():
kid_node = construct_tree_from_dict(value, key)
node.addkid(kid_node)
elif isinstance(data, list):
if all(isinstance(item, dict) for item in data):
for item in data:
kid_node = construct_tree_from_dict(
item,
"<subtree>",
)
node.addkid(kid_node)
else:
for item in data:
node.addkid(Node(f"<leaf>{item}"))
else:
raise Exception(data, node_name)
return node
def cal_acc(pred: dict, answer: dict):
"""
    Calculate normalized tree edit distance (nTED) based accuracy:
    1) construct trees from the dicts,
    2) compute the tree edit distance with the insert/remove/update costs,
    3) divide the distance by the size of the GT tree (i.e., nTED),
    4) convert nTED into an accuracy score: max(1 - nTED, 0).
"""
pred = construct_tree_from_dict(normalize_dict(pred))
answer = construct_tree_from_dict(normalize_dict(answer))
val1 = zss.distance(
pred,
answer,
get_children=zss.Node.get_children,
insert_cost=insert_and_remove_cost,
remove_cost=insert_and_remove_cost,
update_cost=update_cost,
return_operations=False,
)
val2 = zss.distance(
construct_tree_from_dict(normalize_dict({})),
answer,
get_children=zss.Node.get_children,
insert_cost=insert_and_remove_cost,
remove_cost=insert_and_remove_cost,
update_cost=update_cost,
return_operations=False,
)
return max(0, 1 - val1 / val2)
def cal_acc_all(pred_info, answer_info):
acc_info, error_info = {}, {}
for file_name, answer in answer_info.items():
# if file_name not in pred_info:
# print("---> error: pdt not found: {}".format(file_name))
# continue
pred = pred_info.get(file_name, {})
acc = cal_acc(pred, answer)
acc_info[file_name] = acc
if acc < 1.0:
error_info[file_name] = {"acc": acc, "pred": pred, "answer": answer}
error_info = {k: v for k, v in sorted(error_info.items(), key=lambda item: item[1].get("acc", 0))}
    acc_average = sum(list(acc_info.values())) / (len(acc_info) + 1e-6)
    return acc_average, error_info
def normalize_values_of_nested_dict(d, normalize_func):
"""
"""
if isinstance(d, dict):
return {k: normalize_values_of_nested_dict(v, normalize_func) for k, v in d.items()}
elif isinstance(d, list):
return [normalize_values_of_nested_dict(x, normalize_func) if isinstance(x, dict) else x for x in d]
elif isinstance(d, str):
return normalize_func(d)
else:
return d
def eval_donut(pdt_info, gt_info, normalize_func=None, data_name=None):
"""
"""
if normalize_func is not None:
print("--> info: normalize_func executed.")
pdt_info = normalize_values_of_nested_dict(pdt_info, normalize_func)
gt_info = normalize_values_of_nested_dict(gt_info, normalize_func)
f1_score, class_eval_info, error_info = cal_f1_all(pdt_info, gt_info)
acc_average, acc_error_info = cal_acc_all(pdt_info, gt_info)
eval_info = {"f1_score": f1_score, "acc": acc_average, "class_f1_score": class_eval_info,
"f1_error_info": error_info, "acc_error_info": acc_error_info}
print(data_name, "f1_score", f1_score, "acc", acc_average)
return eval_info
def post_process_to_json(qwen_info_str, file_name=None):
try:
if "```json" in qwen_info_str:
if "```" not in qwen_info_str:
qwen_info_str += "```"
qwen_info_group = re.search(r'```json(.*?)```', qwen_info_str, re.DOTALL)
json_str = qwen_info_group.group(1).strip().replace("\n", "")
else:
json_str = qwen_info_str.strip().replace("\n", "")
json_data = json.loads(json_str)
return json_data
except Exception as err: # noqa: F841
return None
def fullwidth_to_halfwidth(text):
    # Convert fullwidth characters to their halfwidth equivalents
result = ''
for char in text:
code_point = ord(char)
        # A fullwidth space maps directly to an ASCII space
if code_point == 0x3000:
code_point = 0x0020
        # Other fullwidth characters (excluding the space) are shifted to their halfwidth counterparts
elif 0xFF01 <= code_point <= 0xFF5E:
code_point -= 0xFEE0
result += chr(code_point)
result = result.replace("", ",")
return result
def remove_unnecessary_spaces(text):
    # Remove spaces between Chinese characters
text = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])', '', text)
    # Remove spaces between Chinese characters and adjacent letters or digits
text = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[a-zA-Z0-9])', '', text)
text = re.sub(r'(?<=[a-zA-Z0-9])\s+(?=[\u4e00-\u9fff])', '', text)
    # Drop unnecessary spaces before punctuation and keep a single space after it
    text = re.sub(r'(?<![0-9])\s*([,.!?:;])\s*', r'\1 ', text)  # punctuation not preceded by a digit
    # Insert a space between digits and letters
text = re.sub(r'(?<=[0-9])(?=[a-zA-Z])', ' ', text)
text = re.sub(r'(?<=[a-zA-Z])(?=[0-9])', ' ', text)
text = re.sub(r'\s+', ' ', text)
return text
class KieEvaluator(BaseMetric):
def response_post_func(self, response_text, **kwargs):
response_text = post_process_to_json(response_text, file_name=kwargs.get('file_name', None))
return response_text
def normalize_func(self, text, **kwargs):
halfwidth_text = fullwidth_to_halfwidth(str(text))
cleaned_text = remove_unnecessary_spaces(halfwidth_text)
return cleaned_text
def evaluate(self, response_info, gt_info, **kwargs):
"""
response_info: dict: {"file_name_1": response, "file_name_2": gt}
gt_info: dict: {"file_name_1": gt, "file_name_2": gt}
kwargs: dataset index config: {'dataset': 'kie_benchmark_POIE', 'group': 'kie', 'op': 'poie', 'num': 250}
"""
# gt should be a dict for kie task, fix for VLMEvalKit
for image_name, label_content in gt_info.items():
if isinstance(label_content, str):
gt_info[image_name] = json.loads(label_content)
response_info = normalize_values_of_nested_dict(response_info, self.normalize_func)
gt_info = normalize_values_of_nested_dict(gt_info, self.normalize_func)
f1_score, class_eval_info, error_info = cal_f1_all(response_info, gt_info)
acc_average, acc_error_info = cal_acc_all(response_info, gt_info)
# summary info
summary_info = {"f1_score": f1_score, "acc": acc_average}
eval_info = {"summary": summary_info, "class_f1_score": class_eval_info,
"f1_error_info": error_info, "acc_error_info": acc_error_info}
return eval_info
if __name__ == '__main__':
pass
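    # Illustrative sanity check (a sketch; real predictions are parsed from model JSON output).
    # One matching field and one mismatched field give a field-level F1 of roughly 0.5.
    toy_gt = {"img_001": {"company": ["ABC Market"], "total": ["12.50"]}}
    toy_pred = {"img_001": {"company": ["ABC Market"], "total": ["12.00"]}}
    demo_f1, _, _ = cal_f1_all(toy_pred, toy_gt)
    print("demo field-level f1: {:.4f}".format(demo_f1))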

View File

@@ -0,0 +1,106 @@
import os
import sys
import json
import re
from collections import Counter
# local import
from .common import BaseMetric
def token_normalize(token_text, is_lower=False, is_alphanum_only=False):
"""
"""
if is_lower:
token_text = token_text.lower()
if is_alphanum_only:
token_text = re.sub('[^A-Za-z0-9]+', '', token_text)
return token_text
def text_normalize_and_tokenize(text, is_keep_blank=True, is_lower=True, is_alphanum_only=False):
text = text.replace("\t", " ").replace("\n", " ").replace("###", "").replace("***", "")
text = re.sub(r'\s+', ' ', text)
if not is_keep_blank:
text = text.replace(" ", "")
text_tokens = text.split(" ") if is_keep_blank else list(text)
text_token_normalized = [token_normalize(t, is_lower, is_alphanum_only) for t in text_tokens]
text_token_normalized = [x for x in text_token_normalized if len(x) > 0]
return text_token_normalized
def evaluate_single_sample(gts, preds):
right_num = 0
gt_counter_info = dict(Counter(gts))
pdt_counter_info = dict(Counter(preds))
for gt_token, gt_count in gt_counter_info.items():
pred_count = pdt_counter_info.get(gt_token, 0)
right_num += min(gt_count, pred_count)
return right_num
def calculate_metrics(response_info, gt_info, is_verbose=False):
"""
"""
macro_recall_list, macro_precision_list, macro_f1_list = [], [], []
total_gt_num, total_pred_num, total_right_num = 0, 0, 0
for file_name, fullbox_gts in gt_info.items():
fullbox_preds = response_info.get(file_name, [])
right_num = evaluate_single_sample(fullbox_gts, fullbox_preds)
total_right_num += right_num
total_gt_num += len(fullbox_gts)
total_pred_num += len(fullbox_preds)
macro_recall = right_num / (len(fullbox_gts) + 1e-9)
macro_precision = right_num / (len(fullbox_preds) + 1e-9)
macro_f1 = 2 * macro_recall * macro_precision / (macro_recall + macro_precision + 1e-9)
macro_recall_list.append(macro_recall)
macro_precision_list.append(macro_precision)
macro_f1_list.append(macro_f1)
    # macro-averaged over images
final_macro_recall = sum(macro_recall_list) / (len(macro_recall_list) + 1e-9)
final_macro_precision = sum(macro_precision_list) / (len(macro_precision_list) + 1e-9)
final_macro_f1 = sum(macro_f1_list) / (len(macro_f1_list) + 1e-9)
# micro
recall_acc = total_right_num / (total_gt_num + 1e-9)
preci_acc = total_right_num / (total_pred_num + 1e-9)
hmean = 2 * recall_acc * preci_acc / (recall_acc + preci_acc + 1e-9)
vbs_eval_result = {
'macro_recall': final_macro_recall, 'macro_precision': final_macro_precision, 'macro_f1_score': final_macro_f1,
        'micro_recall': recall_acc, 'micro_precision': preci_acc, 'micro_f1_score': hmean
}
    eval_result = vbs_eval_result if is_verbose else {'macro_f1_score': final_macro_f1, 'micro_f1_score': hmean}
return eval_result
class OcrEvaluator(BaseMetric):
def response_post_func(self, response_text, **kwargs):
return response_text
def evaluate(self, response_info, gt_info, **kwargs):
        # tokenization settings depend on the dataset and are hard-coded here
dataset_name = kwargs['dataset']
is_word_level, is_lower, is_alphanum_only = True, True, False
if dataset_name in ["Arabic", "Japanese", "Korean"] or "zh" in dataset_name:
is_word_level = False
if "multi_scene_ocr" in self.group_name and is_word_level:
is_alphanum_only = True
eval_config = {"word_level": is_word_level, "alphanum_only": is_alphanum_only, "lowercase": is_lower}
image_pdt_info, image_gt_info = {}, {}
for file_name, gt_src in gt_info.items():
pred_src = response_info.get(file_name, "")
pdt_token_list = text_normalize_and_tokenize(
str(pred_src).strip(), is_word_level, is_lower, is_alphanum_only)
gt_token_list = text_normalize_and_tokenize(
str(gt_src).strip(), is_word_level, is_lower, is_alphanum_only)
image_pdt_info[file_name] = pdt_token_list
image_gt_info[file_name] = gt_token_list
eval_result = calculate_metrics(image_pdt_info, image_gt_info, is_verbose=False)
return {"summary": eval_result, "metric_config": eval_config}
if __name__ == '__main__':
pass
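    # Illustrative sanity check (a sketch; real inputs are model transcriptions and GT texts).
    # With 2 of 3 ground-truth tokens recovered and no spurious tokens, the micro F1 is 0.8.
    gt_tokens = {"img_001": text_normalize_and_tokenize("Hello OCR world")}
    pred_tokens = {"img_001": text_normalize_and_tokenize("hello world")}
    print(calculate_metrics(pred_tokens, gt_tokens, is_verbose=True))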