mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-05 18:29:18 +08:00
Modify eval_mm for MiniCPM-o 2.6
This commit is contained in:
@@ -0,0 +1,59 @@
|
||||
# CC-OCR: A Comprehensive and Challenging OCR Benchmark for Evaluating Large Multimodal Models in Literacy
|
||||
|
||||
## Introduction
|
||||
|
||||
Please refer to our [GitHub](https://github.com/AlibabaResearch/AdvancedLiterateMachinery/tree/main/Benchmarks/CC-OCR) for more information.
|
||||
|
||||
## Running Scripts
|
||||
|
||||
Once the environment is ready, execute the following script from the root directory of VLMEvalKit
|
||||
to perform inference and evaluation tasks in batch.
|
||||
|
||||
```shell
|
||||
MODEL_NAME="QwenVLMax"
|
||||
OUTPUT_DIR="/your/path/to/output_dir"
|
||||
|
||||
SUB_OUTPUT_DIR=${OUTPUT_DIR}/multi_scene_ocr
|
||||
python run.py --data CCOCR_MultiSceneOcr_Cord CCOCR_MultiSceneOcr_Funsd CCOCR_MultiSceneOcr_Iam CCOCR_MultiSceneOcr_ZhDoc CCOCR_MultiSceneOcr_ZhHandwriting CCOCR_MultiSceneOcr_Hieragent CCOCR_MultiSceneOcr_Ic15 CCOCR_MultiSceneOcr_Inversetext CCOCR_MultiSceneOcr_Totaltext CCOCR_MultiSceneOcr_ZhScene CCOCR_MultiSceneOcr_UgcLaion CCOCR_MultiSceneOcr_ZhDense CCOCR_MultiSceneOcr_ZhVertical --model ${MODEL_NAME} --work-dir ${SUB_OUTPUT_DIR} --verbose
|
||||
python vlmeval/dataset/utils/ccocr_evaluator/common.py ${SUB_OUTPUT_DIR}
|
||||
|
||||
SUB_OUTPUT_DIR=${OUTPUT_DIR}/multi_lan_ocr
|
||||
python run.py --data CCOCR_MultiLanOcr_Arabic CCOCR_MultiLanOcr_French CCOCR_MultiLanOcr_German CCOCR_MultiLanOcr_Italian CCOCR_MultiLanOcr_Japanese CCOCR_MultiLanOcr_Korean CCOCR_MultiLanOcr_Portuguese CCOCR_MultiLanOcr_Russian CCOCR_MultiLanOcr_Spanish CCOCR_MultiLanOcr_Vietnamese --model ${MODEL_NAME} --work-dir ${SUB_OUTPUT_DIR} --verbose
|
||||
python vlmeval/dataset/utils/ccocr_evaluator/common.py ${SUB_OUTPUT_DIR}
|
||||
|
||||
SUB_OUTPUT_DIR=${OUTPUT_DIR}/doc_parsing
|
||||
python run.py --data CCOCR_DocParsing_DocPhotoChn CCOCR_DocParsing_DocPhotoEng CCOCR_DocParsing_DocScanChn CCOCR_DocParsing_DocScanEng CCOCR_DocParsing_TablePhotoChn CCOCR_DocParsing_TablePhotoEng CCOCR_DocParsing_TableScanChn CCOCR_DocParsing_TableScanEng CCOCR_DocParsing_MolecularHandwriting CCOCR_DocParsing_FormulaHandwriting --model ${MODEL_NAME} --work-dir ${SUB_OUTPUT_DIR} --verbose
|
||||
python vlmeval/dataset/utils/ccocr_evaluator/common.py ${SUB_OUTPUT_DIR}
|
||||
|
||||
SUB_OUTPUT_DIR=${OUTPUT_DIR}/kie
|
||||
python run.py --data CCOCR_Kie_Sroie2019Word CCOCR_Kie_Cord CCOCR_Kie_EphoieScut CCOCR_Kie_Poie CCOCR_Kie_ColdSibr CCOCR_Kie_ColdCell --model ${MODEL_NAME} --work-dir ${SUB_OUTPUT_DIR} --verbose
|
||||
python vlmeval/dataset/utils/ccocr_evaluator/common.py ${SUB_OUTPUT_DIR}
|
||||
```
|
||||
|
||||
## Example Output
|
||||
The evaluation results will be saved in `${SUB_OUTPUT_DIR}/summary.md`. For example, for the KIE subset,
|
||||
the output is as follows:
|
||||
|
||||
| exp_name(f1_score) | COLD_CELL | COLD_SIBR | CORD | EPHOIE_SCUT | POIE | sroie2019_word | summary |
|
||||
|:-------------------|------------:|------------:|-------:|--------------:|-------:|-----------------:|----------:|
|
||||
| QwenVLMax | 81.01 | 72.46 | 69.33 | 71.2 | 60.85 | 76.37 | 71.87 |
|
||||
|
||||
|
||||
## Citation
|
||||
If you find our work helpful, feel free to give us a cite.
|
||||
|
||||
```
|
||||
@misc{yang2024ccocr,
|
||||
title={CC-OCR: A Comprehensive and Challenging OCR Benchmark for Evaluating Large Multimodal Models in Literacy},
|
||||
author={Zhibo Yang and Jun Tang and Zhaohai Li and Pengfei Wang and Jianqiang Wan and Humen Zhong and Xuejing Liu and Mingkun Yang and Peng Wang and Shuai Bai and LianWen Jin and Junyang Lin},
|
||||
year={2024},
|
||||
eprint={2412.02210},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV},
|
||||
url={https://arxiv.org/abs/2412.02210},
|
||||
}
|
||||
```
|
||||
|
||||
## Contact Us
|
||||
|
||||
If you have any questions, feel free to send an email to: wpf272043@alibaba-inc.com or xixing.tj@alibaba-inc.com
|
||||
@@ -0,0 +1,12 @@
|
||||
from .kie_evaluator import KieEvaluator
|
||||
from .doc_parsing_evaluator import ParsingEvaluator
|
||||
from .ocr_evaluator import OcrEvaluator
|
||||
from .common import summary
|
||||
|
||||
|
||||
evaluator_map_info = {
|
||||
"kie": KieEvaluator("kie"),
|
||||
"doc_parsing": ParsingEvaluator("doc_parsing"),
|
||||
"multi_lan_ocr": OcrEvaluator("multi_lan_ocr"),
|
||||
"multi_scene_ocr": OcrEvaluator("multi_scene_ocr")
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from tabulate import tabulate
|
||||
|
||||
|
||||
def pick_response_text(json_path):
|
||||
"""
|
||||
"""
|
||||
try:
|
||||
with open(json_path, "r") as f:
|
||||
json_data = json.load(f)
|
||||
except Exception as e:
|
||||
print("--> file error: msg: {}, path: {}".format(e, json_path))
|
||||
return None
|
||||
|
||||
for required_key in ["model_name", "response"]:
|
||||
if required_key not in json_data:
|
||||
print("--> required key not exists, name: {}, path: {}".format(required_key, json_path))
|
||||
return None
|
||||
|
||||
model_name = json_data["model_name"]
|
||||
model_response = json_data["response"]
|
||||
|
||||
response_text = None
|
||||
if model_name.startswith("gpt") or model_name.startswith("o1"):
|
||||
response_text = model_response.get("data", {}).get("response", {}).get("choices", [{}])[0].get("message", {}).get("content", None) # noqa: E501
|
||||
elif model_name.startswith("local_"):
|
||||
response_text = model_response
|
||||
else:
|
||||
if model_name.startswith("claude"):
|
||||
content_list = model_response.get("content", None)
|
||||
elif model_name.startswith("gemini"):
|
||||
content_list = model_response.get("candidates", [{}])[0].get("content", {}).get("parts", None)
|
||||
elif model_name.startswith("qwen"):
|
||||
content_list = model_response.get("output", {}).get("choices", [{}])[0].get("message", {}).get("content", None) # noqa: E501
|
||||
else:
|
||||
raise NotImplementedError("The pick_response_text NOT implemented for model: {}".format(model_name))
|
||||
|
||||
if isinstance(content_list, list) and len(content_list) > 0:
|
||||
response_text = content_list[0].get("text", None)
|
||||
|
||||
if response_text is None:
|
||||
print("--> [error][{}] text pick error, path: {}".format(model_name, json_path))
|
||||
return response_text
|
||||
|
||||
|
||||
def load_response_from_dir(res_dir):
|
||||
"""
|
||||
"""
|
||||
response_info = {}
|
||||
for file_name in os.listdir(res_dir):
|
||||
file_path = os.path.abspath(os.path.join(res_dir, file_name))
|
||||
if not file_name.endswith(".json"):
|
||||
print("--> skip: result file should be a json: but got: {}".format(file_path))
|
||||
continue
|
||||
|
||||
response_text = pick_response_text(file_path)
|
||||
if response_text is None:
|
||||
continue
|
||||
|
||||
file_name_wo_ext, ext = os.path.splitext(file_name)
|
||||
response_info[file_name_wo_ext] = response_text
|
||||
return response_info
|
||||
|
||||
|
||||
class BaseMetric(object):
|
||||
""" BaseMetric """
|
||||
""" OCRMetric """
|
||||
def __init__(self, group_name, **kwargs):
|
||||
self.group_name = group_name
|
||||
self.kwargs = kwargs
|
||||
|
||||
def response_post_func(self, response_text, **kwargs):
|
||||
return response_text
|
||||
|
||||
@abstractmethod
|
||||
# Given the prediction and gt, return the evaluation results in the format of a dictionary
|
||||
# results should contain a 'summary' key, for example:
|
||||
# {
|
||||
# "summary": {
|
||||
# "f1-score": 99.99,
|
||||
# "metric_name": "metric_value" # used for summary,only metric info could be placed in this dict.
|
||||
# },
|
||||
# "your other info": "xxx"
|
||||
# }
|
||||
def evaluate(self, response_info, gt_info, normalize_func=None, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, pdt_res_dir, gt_info, with_response_ratio=True, **kwargs):
|
||||
if isinstance(pdt_res_dir, dict):
|
||||
raw_response_info = pdt_res_dir
|
||||
elif os.path.exists(pdt_res_dir) and os.path.isdir(pdt_res_dir):
|
||||
raw_response_info = load_response_from_dir(pdt_res_dir)
|
||||
else:
|
||||
return ValueError("invalid input: response dict or folder are required, but got {}".format(pdt_res_dir))
|
||||
|
||||
post_error_list, response_info = [], {}
|
||||
response_error_list = list(gt_info.keys() - raw_response_info.keys())
|
||||
for file_name, single_pdt_str in raw_response_info.items():
|
||||
single_pdt_str = self.response_post_func(single_pdt_str, **kwargs)
|
||||
if single_pdt_str is None:
|
||||
post_error_list.append(file_name)
|
||||
continue
|
||||
response_info[file_name] = single_pdt_str
|
||||
|
||||
meta_info = {
|
||||
"gt_total_num": len(gt_info), "pdt_total_num": len(response_info),
|
||||
"post_error_list": post_error_list, "response_error_list": response_error_list,
|
||||
}
|
||||
eval_info = self.evaluate(response_info, gt_info, **kwargs)
|
||||
|
||||
# add response_success_ratio
|
||||
if "summary" in eval_info and with_response_ratio:
|
||||
success_ratio = (len(response_info) + len(post_error_list)) / (len(gt_info) + 1e-9)
|
||||
eval_info["summary"].update({"response_success_ratio": success_ratio})
|
||||
return meta_info, eval_info
|
||||
|
||||
|
||||
def summary(index_path, exp_dir_base, is_weighted_sum=False):
|
||||
"""
|
||||
"""
|
||||
with open(index_path, "r") as f:
|
||||
data_list = json.load(f)
|
||||
|
||||
all_data_info = {}
|
||||
for data_info_item in data_list:
|
||||
data_name = data_info_item["dataset"]
|
||||
if not data_info_item.get("release", True):
|
||||
continue
|
||||
all_data_info[data_name] = data_info_item
|
||||
dataset_list = list(all_data_info.keys())
|
||||
summary_path = summary_multi_exp(exp_dir_base, dataset_list, is_weighted_sum=is_weighted_sum)
|
||||
return summary_path
|
||||
|
||||
|
||||
def summary_multi_exp(exp_dir_base, dataset_list=None, is_weighted_sum=False):
|
||||
"""
|
||||
"""
|
||||
if dataset_list is None:
|
||||
all_dataset_name = []
|
||||
for exp_name in os.listdir(exp_dir_base):
|
||||
dir_status_path = os.path.join(exp_dir_base, exp_name, "status.json")
|
||||
if not os.path.exists(dir_status_path):
|
||||
continue
|
||||
with open(dir_status_path, "r") as f:
|
||||
data_status_info = json.load(f)
|
||||
all_dataset_name.extend(data_status_info.keys())
|
||||
dataset_list = sorted(set(all_dataset_name))
|
||||
|
||||
# summary main code
|
||||
all_evaluate_info, _ = {}, 0
|
||||
for exp_name in os.listdir(exp_dir_base):
|
||||
dir_status_path = os.path.join(exp_dir_base, exp_name, "status.json")
|
||||
if not os.path.exists(dir_status_path):
|
||||
print("--> skip: status.json not exist: {}".format(dir_status_path))
|
||||
continue
|
||||
|
||||
with open(dir_status_path, "r") as f:
|
||||
all_status_info = json.load(f)
|
||||
|
||||
for data_name in dataset_list:
|
||||
total_num = all_status_info.get(data_name, {}).get("config", {}).get("num", "-1")
|
||||
summary_info = all_status_info.get(data_name, {}).get("evaluation", {}).get("summary", {})
|
||||
for metric_name, metric_value in summary_info.items():
|
||||
if metric_name not in all_evaluate_info:
|
||||
all_evaluate_info[metric_name] = {}
|
||||
if exp_name not in all_evaluate_info[metric_name]:
|
||||
all_evaluate_info[metric_name][exp_name] = {}
|
||||
all_evaluate_info[metric_name][exp_name][data_name] = (metric_value, total_num)
|
||||
|
||||
all_table_md = []
|
||||
for metric_name, metric_info in all_evaluate_info.items():
|
||||
formatted_time = time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time()))
|
||||
summary_line_list = []
|
||||
summary_key_name = "summary(weighted)" if is_weighted_sum else "summary"
|
||||
summary_head = [f"exp_name({metric_name}_{formatted_time})"] + dataset_list + [summary_key_name]
|
||||
for exp_name, data_eval_info in metric_info.items():
|
||||
summary_line = [exp_name, ]
|
||||
|
||||
all_metric_value = 0
|
||||
is_summary_valid, all_total_num, all_weighted_metric = True, 0, 0
|
||||
for data_name in dataset_list:
|
||||
metric_value, total_num = data_eval_info.get(data_name, ("-1", "-1"))
|
||||
summary_line.append("{:.2f}".format(float(metric_value) * 100))
|
||||
if str(metric_value) == "-1" or str(metric_value) == "-1":
|
||||
is_summary_valid = False
|
||||
continue
|
||||
|
||||
all_total_num += float(total_num)
|
||||
all_weighted_metric += float(total_num) * float(metric_value)
|
||||
all_metric_value += float(metric_value)
|
||||
|
||||
summary_value_valid = ((all_weighted_metric / (all_total_num + 1e-9)) * 100) if is_weighted_sum \
|
||||
else (all_metric_value / (len(dataset_list) + 1e-9) * 100)
|
||||
summary_value = "-" if not is_summary_valid else "{:.2f}".format(summary_value_valid)
|
||||
summary_line.append(summary_value)
|
||||
summary_line_list.append(summary_line)
|
||||
|
||||
md_table_info = tabulate(summary_line_list, headers=summary_head, tablefmt='pipe')
|
||||
all_table_md.append(md_table_info)
|
||||
|
||||
print("\n\n".join(all_table_md))
|
||||
summary_path = os.path.abspath(os.path.join(exp_dir_base, "summary.md"))
|
||||
with open(summary_path, "w") as f:
|
||||
f.write("\n\n".join(all_table_md))
|
||||
return summary_path
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python {} exp_base_dir".format(__file__))
|
||||
exit(-1)
|
||||
else:
|
||||
print('--> info: {}'.format(sys.argv))
|
||||
exp_base_dir = sys.argv[1]
|
||||
|
||||
summary_path = summary_multi_exp(exp_base_dir, dataset_list=None, is_weighted_sum=False)
|
||||
print("--> info: summary saved at : {}".format(summary_path))
|
||||
print("happy coding.")
|
||||
@@ -0,0 +1,256 @@
|
||||
import nltk
|
||||
import re
|
||||
from tqdm import tqdm
|
||||
from collections import deque
|
||||
from apted.helpers import Tree
|
||||
from apted import APTED, Config
|
||||
|
||||
# local import
|
||||
from .common import BaseMetric
|
||||
|
||||
|
||||
# 移除指定的LaTeX命令
|
||||
patterns = [
|
||||
r'\\documentclass\{.*?\}',
|
||||
r'\\usepackage\[.*?\]\{.*?\}',
|
||||
r'\\usepackage\{.*?\}',
|
||||
r'\\geometry\{.*?\}',
|
||||
r'\\begin\{document\}',
|
||||
r'\\end\{document\}',
|
||||
r'\\noindent'
|
||||
]
|
||||
|
||||
|
||||
class TableTree(Tree):
|
||||
"""
|
||||
# Copyright 2020 IBM
|
||||
# Author: peter.zhong@au1.ibm.com
|
||||
# License: Apache 2.0 License.
|
||||
"""
|
||||
def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
|
||||
self.tag = tag
|
||||
self.colspan = colspan
|
||||
self.rowspan = rowspan
|
||||
self.content = content
|
||||
self.children = list(children)
|
||||
|
||||
def bracket(self):
|
||||
"""Show tree using brackets notation"""
|
||||
if self.tag == "td":
|
||||
result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
|
||||
self.tag,
|
||||
self.colspan,
|
||||
self.rowspan,
|
||||
self.content,
|
||||
)
|
||||
else:
|
||||
result = '"tag": %s' % self.tag
|
||||
for child in self.children:
|
||||
result += child.bracket()
|
||||
return "{{{}}}".format(result)
|
||||
|
||||
|
||||
class CustomConfig(Config):
|
||||
"""
|
||||
# Copyright 2020 IBM
|
||||
# Author: peter.zhong@au1.ibm.com
|
||||
# License: Apache 2.0 License.
|
||||
"""
|
||||
def rename(self, node1, node2):
|
||||
"""Compares attributes of trees"""
|
||||
# print(node1.tag)
|
||||
if (
|
||||
(node1.tag != node2.tag)
|
||||
or (node1.colspan != node2.colspan)
|
||||
or (node1.rowspan != node2.rowspan)
|
||||
):
|
||||
return 1.0
|
||||
if node1.tag == "td":
|
||||
if node1.content or node2.content:
|
||||
return nltk.edit_distance(node1.content, node2.content) / max(len(node1.content), len(node2.content))
|
||||
return 0.0
|
||||
|
||||
|
||||
class TEDS(object):
|
||||
"""Tree Edit Distance basead Similarity
|
||||
# Copyright 2020 IBM
|
||||
# Author: peter.zhong@au1.ibm.com
|
||||
# License: Apache 2.0 License.
|
||||
"""
|
||||
def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
|
||||
assert isinstance(n_jobs, int) and (
|
||||
n_jobs >= 1
|
||||
), "n_jobs must be an integer greather than 1"
|
||||
self.structure_only = structure_only
|
||||
self.n_jobs = n_jobs
|
||||
self.ignore_nodes = ignore_nodes
|
||||
self.__tokens__ = []
|
||||
|
||||
def tokenize(self, node):
|
||||
"""Tokenizes table cells"""
|
||||
self.__tokens__.append("<%s>" % node.tag)
|
||||
if node.text is not None:
|
||||
self.__tokens__ += list(node.text)
|
||||
for n in node.getchildren():
|
||||
self.tokenize(n)
|
||||
if node.tag != "unk":
|
||||
self.__tokens__.append("</%s>" % node.tag)
|
||||
if node.tag != "td" and node.tail is not None:
|
||||
self.__tokens__ += list(node.tail)
|
||||
|
||||
def load_html_tree(self, node, parent=None):
|
||||
"""Converts HTML tree to the format required by apted"""
|
||||
global __tokens__
|
||||
if node.tag == "td":
|
||||
if self.structure_only:
|
||||
cell = []
|
||||
else:
|
||||
self.__tokens__ = []
|
||||
self.tokenize(node)
|
||||
cell = self.__tokens__[1:-1].copy()
|
||||
new_node = TableTree(
|
||||
node.tag,
|
||||
int(node.attrib.get("colspan", "1")),
|
||||
int(node.attrib.get("rowspan", "1")),
|
||||
cell,
|
||||
*deque(),
|
||||
)
|
||||
else:
|
||||
new_node = TableTree(node.tag, None, None, None, *deque())
|
||||
if parent is not None:
|
||||
parent.children.append(new_node)
|
||||
if node.tag != "td":
|
||||
for n in node.getchildren():
|
||||
self.load_html_tree(n, new_node)
|
||||
if parent is None:
|
||||
return new_node
|
||||
|
||||
def evaluate(self, pred, true):
|
||||
"""Computes TEDS score between the prediction and the ground truth of a
|
||||
given sample
|
||||
"""
|
||||
# try_import("lxml")
|
||||
from lxml import etree, html
|
||||
if (not pred) or (not true):
|
||||
return 0.0
|
||||
|
||||
parser = html.HTMLParser(remove_comments=True, encoding="utf-8")
|
||||
pred = html.fromstring(pred, parser=parser)
|
||||
true = html.fromstring(true, parser=parser)
|
||||
if pred.xpath("body/table") and true.xpath("body/table"):
|
||||
pred = pred.xpath("body/table")[0]
|
||||
true = true.xpath("body/table")[0]
|
||||
if self.ignore_nodes:
|
||||
etree.strip_tags(pred, *self.ignore_nodes)
|
||||
etree.strip_tags(true, *self.ignore_nodes)
|
||||
n_nodes_pred = len(pred.xpath(".//*"))
|
||||
n_nodes_true = len(true.xpath(".//*"))
|
||||
n_nodes = max(n_nodes_pred, n_nodes_true)
|
||||
tree_pred = self.load_html_tree(pred)
|
||||
tree_true = self.load_html_tree(true)
|
||||
distance = APTED(
|
||||
tree_pred, tree_true, CustomConfig()
|
||||
).compute_edit_distance()
|
||||
return 1.0 - (float(distance) / n_nodes)
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
|
||||
class ParsingEvaluator(BaseMetric):
|
||||
def response_post_func(self, response_text, **kwargs):
|
||||
return response_text
|
||||
|
||||
def evaluate(self, response_info, gt_info, **kwargs):
|
||||
op = kwargs['op']
|
||||
if op == 'doc':
|
||||
score = self.eval_doc(response_info, gt_info)
|
||||
elif op == 'table':
|
||||
score = self.eval_table(response_info, gt_info)
|
||||
elif op in ['molecular', "formula"]:
|
||||
score = self.eval_formula(response_info, gt_info, op_name=op)
|
||||
else:
|
||||
raise ValueError(f'doc parsing unsupported op: {op}')
|
||||
|
||||
# summary info
|
||||
eval_info = {"summary": {"score": score}}
|
||||
return eval_info
|
||||
|
||||
def eval_doc(self, response_info, gt_info):
|
||||
results = []
|
||||
for img_name, gt in tqdm(gt_info.items()):
|
||||
if img_name not in response_info:
|
||||
results.append(0)
|
||||
continue
|
||||
|
||||
pred = response_info[img_name]
|
||||
for pattern in patterns:
|
||||
pred = re.sub(pattern, '', pred)
|
||||
|
||||
try:
|
||||
pred = pred.split('```')[1]
|
||||
except:
|
||||
pass
|
||||
|
||||
pred = pred.replace('```latex', '')
|
||||
pred = pred.replace('```', '')
|
||||
|
||||
pred = pred.replace(' ', '').replace('\n', '')
|
||||
gt = gt.replace(' ', '').replace('\n', '')
|
||||
|
||||
edit_dist = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
|
||||
results.append(1 - edit_dist)
|
||||
|
||||
score = sum(results) / len(results)
|
||||
return score
|
||||
|
||||
def eval_table(self, response_info, gt_info):
|
||||
teds = TEDS(structure_only=False, n_jobs=1)
|
||||
results = []
|
||||
for img_name, gt in tqdm(gt_info.items()):
|
||||
if img_name not in response_info:
|
||||
results.append(0)
|
||||
continue
|
||||
|
||||
pred = response_info[img_name]
|
||||
for pattern in patterns:
|
||||
pred = re.sub(pattern, '', pred)
|
||||
|
||||
try:
|
||||
pred = pred.split('```html')[1]
|
||||
except:
|
||||
pass
|
||||
|
||||
pred = pred.replace('```', '')
|
||||
pred = pred.replace(' ', '').replace('\n', '').replace(',', ',')
|
||||
gt = gt.replace(' ', '').replace('\n', '')
|
||||
|
||||
pred_html = '<html><body>{}</body></html>'.format(pred)
|
||||
gt_html = '<html><body>{}</body></html>'.format(gt)
|
||||
results.append(teds.evaluate(pred_html, gt_html))
|
||||
|
||||
score = sum(results) / len(results)
|
||||
return score
|
||||
|
||||
def eval_formula(self, response_info, gt_info, op_name='formula'):
|
||||
results = []
|
||||
for img_name, gt in tqdm(gt_info.items()):
|
||||
if img_name not in response_info:
|
||||
results.append(0)
|
||||
continue
|
||||
|
||||
pred = response_info[img_name]
|
||||
|
||||
if op_name == 'formula':
|
||||
pred = pred.replace("\n", " ").replace("```latex", "").replace("```", "").replace("\t", " ").replace(" ", "") # noqa: E501
|
||||
gt = gt.replace(" ", "")
|
||||
elif op_name == 'molecular':
|
||||
pred = pred.replace("\n", "").replace(" ", "").replace("<smiles>", "").replace("</smiles>", "")
|
||||
gt = gt.replace(" ", "")
|
||||
edit_dist = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
|
||||
results.append(1 - edit_dist)
|
||||
score = sum(results) / len(results)
|
||||
return score
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
@@ -0,0 +1,385 @@
|
||||
|
||||
"""
|
||||
Donut
|
||||
Copyright (c) 2022-present NAVER Corp.
|
||||
MIT License
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import zss
|
||||
from zss import Node
|
||||
from collections import Counter
|
||||
from nltk import edit_distance
|
||||
|
||||
# local import
|
||||
from .common import BaseMetric
|
||||
|
||||
|
||||
def flatten(data: dict):
|
||||
"""
|
||||
Convert Dictionary into Non-nested Dictionary
|
||||
Example:
|
||||
input(dict)
|
||||
{
|
||||
"menu": [
|
||||
{"name" : ["cake"], "count" : ["2"]},
|
||||
{"name" : ["juice"], "count" : ["1"]},
|
||||
]
|
||||
}
|
||||
output(list)
|
||||
[
|
||||
("menu.name", "cake"),
|
||||
("menu.count", "2"),
|
||||
("menu.name", "juice"),
|
||||
("menu.count", "1"),
|
||||
]
|
||||
"""
|
||||
flatten_data = list()
|
||||
|
||||
def _flatten(value, key=""):
|
||||
if type(value) is dict:
|
||||
for child_key, child_value in value.items():
|
||||
_flatten(child_value, f"{key}.{child_key}" if key else child_key)
|
||||
elif type(value) is list:
|
||||
for value_item in value:
|
||||
_flatten(value_item, key)
|
||||
else:
|
||||
flatten_data.append((key, value))
|
||||
|
||||
_flatten(data)
|
||||
return flatten_data
|
||||
|
||||
|
||||
def update_cost(node1: Node, node2: Node):
|
||||
"""
|
||||
Update cost for tree edit distance.
|
||||
If both are leaf node, calculate string edit distance between two labels (special token '<leaf>' will be ignored).
|
||||
If one of them is leaf node, cost is length of string in leaf node + 1.
|
||||
If neither are leaf node, cost is 0 if label1 is same with label2 othewise 1
|
||||
"""
|
||||
label1 = node1.label
|
||||
label2 = node2.label
|
||||
label1_leaf = "<leaf>" in label1
|
||||
label2_leaf = "<leaf>" in label2
|
||||
if label1_leaf and label2_leaf:
|
||||
return edit_distance(label1.replace("<leaf>", ""), label2.replace("<leaf>", ""))
|
||||
elif not label1_leaf and label2_leaf:
|
||||
return 1 + len(label2.replace("<leaf>", ""))
|
||||
elif label1_leaf and not label2_leaf:
|
||||
return 1 + len(label1.replace("<leaf>", ""))
|
||||
else:
|
||||
return int(label1 != label2)
|
||||
|
||||
|
||||
def insert_and_remove_cost(node: Node):
|
||||
"""
|
||||
Insert and remove cost for tree edit distance.
|
||||
If leaf node, cost is length of label name.
|
||||
Otherwise, 1
|
||||
"""
|
||||
label = node.label
|
||||
if "<leaf>" in label:
|
||||
return len(label.replace("<leaf>", ""))
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def normalize_dict(data: Union[Dict, List, Any]):
|
||||
"""
|
||||
Sort by value, while iterate over element if data is list
|
||||
"""
|
||||
# if not data:
|
||||
# return {}
|
||||
|
||||
if isinstance(data, dict):
|
||||
new_data = dict()
|
||||
for key in sorted(data.keys(), key=lambda k: (len(k), k)):
|
||||
value = normalize_dict(data[key])
|
||||
if value:
|
||||
if not isinstance(value, list):
|
||||
value = [value]
|
||||
new_data[key] = value
|
||||
|
||||
elif isinstance(data, list):
|
||||
if all(isinstance(item, dict) for item in data):
|
||||
new_data = []
|
||||
for item in data:
|
||||
item = normalize_dict(item)
|
||||
if item:
|
||||
new_data.append(item)
|
||||
else:
|
||||
new_data = [str(item).strip() for item in data if type(item) in {str, int, float} and str(item).strip()]
|
||||
else:
|
||||
new_data = [str(data).strip()]
|
||||
return new_data
|
||||
|
||||
|
||||
def cal_f1_all(preds, answers):
|
||||
"""
|
||||
Calculate global F1 accuracy score (field-level, micro-averaged) by counting all true positives,
|
||||
false negatives and false positives
|
||||
"""
|
||||
metric_info, error_info = {}, {}
|
||||
total_tp, total_fn_or_fp = 0, 0
|
||||
for file_name, answer in answers.items():
|
||||
sample_error_info = {"fp": [], "fn": [], "tp": []}
|
||||
pred = preds.get(file_name, {})
|
||||
pred, answer = flatten(normalize_dict(pred)), flatten(normalize_dict(answer))
|
||||
for field in pred:
|
||||
field_name = field[0]
|
||||
if field_name not in metric_info:
|
||||
metric_info[field_name] = {"total_tp": 0, "total_fn_or_fp": 0}
|
||||
if field in answer:
|
||||
total_tp += 1
|
||||
metric_info[field_name]["total_tp"] += 1
|
||||
sample_error_info["tp"].append(field)
|
||||
answer.remove(field)
|
||||
else:
|
||||
total_fn_or_fp += 1
|
||||
metric_info[field_name]["total_fn_or_fp"] += 1
|
||||
sample_error_info["fp"].append(field)
|
||||
|
||||
total_fn_or_fp += len(answer)
|
||||
for field in answer:
|
||||
field_name = field[0]
|
||||
if field_name not in metric_info:
|
||||
metric_info[field_name] = {"total_tp": 0, "total_fn_or_fp": 0}
|
||||
metric_info[field_name]["total_fn_or_fp"] += 1
|
||||
sample_error_info["fn"].append(field)
|
||||
|
||||
sample_error_num = sum([len(v) for k, v in sample_error_info.items() if k != "tp"])
|
||||
if sample_error_num > 0:
|
||||
sample_error_info["error_num"] = sample_error_num
|
||||
error_class_list = ["counter_" + x[0] for x in (sample_error_info["fn"] + sample_error_info["fp"])]
|
||||
counter = Counter(error_class_list)
|
||||
sample_error_info["error_info"] = dict(counter)
|
||||
error_info[file_name] = sample_error_info
|
||||
|
||||
# summary
|
||||
for field_name, field_info in metric_info.items():
|
||||
field_tp, field_fn_or_fp = field_info["total_tp"], field_info["total_fn_or_fp"]
|
||||
metric_info[field_name]["acc"] = field_tp / (field_tp + field_fn_or_fp / 2 + 1e-6)
|
||||
|
||||
print("donut_evaluator: total_tp: {}, total_fn_or_fp: {}, ptd_num: {}, gt_num: {}".format(total_tp, total_fn_or_fp,
|
||||
len(preds), len(answers)))
|
||||
error_info = {k: v for k, v in
|
||||
sorted(error_info.items(), key=lambda item: item[1].get("error_num", 0), reverse=True)}
|
||||
metric_info = {k: v for k, v in
|
||||
sorted(metric_info.items(), key=lambda item: item[1].get("total_fn_or_fp", 0), reverse=True)}
|
||||
return total_tp / (total_tp + total_fn_or_fp / 2 + 1e-6), metric_info, error_info
|
||||
|
||||
|
||||
def construct_tree_from_dict(data: Union[Dict, List], node_name: str = None):
|
||||
"""
|
||||
Convert Dictionary into Tree
|
||||
|
||||
Example:
|
||||
input(dict)
|
||||
|
||||
{
|
||||
"menu": [
|
||||
{"name" : ["cake"], "count" : ["2"]},
|
||||
{"name" : ["juice"], "count" : ["1"]},
|
||||
]
|
||||
}
|
||||
|
||||
output(tree)
|
||||
<root>
|
||||
|
|
||||
menu
|
||||
/ \
|
||||
<subtree> <subtree>
|
||||
/ | | \
|
||||
name count name count
|
||||
/ | | \
|
||||
<leaf>cake <leaf>2 <leaf>juice <leaf>1
|
||||
"""
|
||||
if node_name is None:
|
||||
node_name = "<root>"
|
||||
|
||||
node = Node(node_name)
|
||||
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
kid_node = construct_tree_from_dict(value, key)
|
||||
node.addkid(kid_node)
|
||||
elif isinstance(data, list):
|
||||
if all(isinstance(item, dict) for item in data):
|
||||
for item in data:
|
||||
kid_node = construct_tree_from_dict(
|
||||
item,
|
||||
"<subtree>",
|
||||
)
|
||||
node.addkid(kid_node)
|
||||
else:
|
||||
for item in data:
|
||||
node.addkid(Node(f"<leaf>{item}"))
|
||||
else:
|
||||
raise Exception(data, node_name)
|
||||
return node
|
||||
|
||||
|
||||
def cal_acc(pred: dict, answer: dict):
|
||||
"""
|
||||
Calculate normalized tree edit distance(nTED) based accuracy.
|
||||
1) Construct tree from dict,
|
||||
2) Get tree distance with insert/remove/update cost,
|
||||
3) Divide distance with GT tree size (i.e., nTED),
|
||||
4) Calculate nTED based accuracy. (= max(1 - nTED, 0 ).
|
||||
"""
|
||||
pred = construct_tree_from_dict(normalize_dict(pred))
|
||||
answer = construct_tree_from_dict(normalize_dict(answer))
|
||||
val1 = zss.distance(
|
||||
pred,
|
||||
answer,
|
||||
get_children=zss.Node.get_children,
|
||||
insert_cost=insert_and_remove_cost,
|
||||
remove_cost=insert_and_remove_cost,
|
||||
update_cost=update_cost,
|
||||
return_operations=False,
|
||||
)
|
||||
val2 = zss.distance(
|
||||
construct_tree_from_dict(normalize_dict({})),
|
||||
answer,
|
||||
get_children=zss.Node.get_children,
|
||||
insert_cost=insert_and_remove_cost,
|
||||
remove_cost=insert_and_remove_cost,
|
||||
update_cost=update_cost,
|
||||
return_operations=False,
|
||||
)
|
||||
return max(0, 1 - val1 / val2)
|
||||
|
||||
|
||||
def cal_acc_all(pred_info, answer_info):
|
||||
acc_info, error_info = {}, {}
|
||||
for file_name, answer in answer_info.items():
|
||||
# if file_name not in pred_info:
|
||||
# print("---> error: pdt not found: {}".format(file_name))
|
||||
# continue
|
||||
pred = pred_info.get(file_name, {})
|
||||
acc = cal_acc(pred, answer)
|
||||
acc_info[file_name] = acc
|
||||
if acc < 1.0:
|
||||
error_info[file_name] = {"acc": acc, "pred": pred, "answer": answer}
|
||||
|
||||
error_info = {k: v for k, v in sorted(error_info.items(), key=lambda item: item[1].get("acc", 0))}
|
||||
acc_averge = sum(list(acc_info.values())) / (len(acc_info) + 1e-6)
|
||||
return acc_averge, error_info
|
||||
|
||||
|
||||
def normalize_values_of_nested_dict(d, normalize_func):
|
||||
"""
|
||||
"""
|
||||
if isinstance(d, dict):
|
||||
return {k: normalize_values_of_nested_dict(v, normalize_func) for k, v in d.items()}
|
||||
elif isinstance(d, list):
|
||||
return [normalize_values_of_nested_dict(x, normalize_func) if isinstance(x, dict) else x for x in d]
|
||||
elif isinstance(d, str):
|
||||
return normalize_func(d)
|
||||
else:
|
||||
return d
|
||||
|
||||
|
||||
def eval_donut(pdt_info, gt_info, normalize_func=None, data_name=None):
|
||||
"""
|
||||
"""
|
||||
if normalize_func is not None:
|
||||
print("--> info: normalize_func executed.")
|
||||
pdt_info = normalize_values_of_nested_dict(pdt_info, normalize_func)
|
||||
gt_info = normalize_values_of_nested_dict(gt_info, normalize_func)
|
||||
|
||||
f1_score, class_eval_info, error_info = cal_f1_all(pdt_info, gt_info)
|
||||
acc_average, acc_error_info = cal_acc_all(pdt_info, gt_info)
|
||||
eval_info = {"f1_score": f1_score, "acc": acc_average, "class_f1_score": class_eval_info,
|
||||
"f1_error_info": error_info, "acc_error_info": acc_error_info}
|
||||
print(data_name, "f1_score", f1_score, "acc", acc_average)
|
||||
return eval_info
|
||||
|
||||
|
||||
def post_process_to_json(qwen_info_str, file_name=None):
|
||||
try:
|
||||
if "```json" in qwen_info_str:
|
||||
if "```" not in qwen_info_str:
|
||||
qwen_info_str += "```"
|
||||
qwen_info_group = re.search(r'```json(.*?)```', qwen_info_str, re.DOTALL)
|
||||
json_str = qwen_info_group.group(1).strip().replace("\n", "")
|
||||
else:
|
||||
json_str = qwen_info_str.strip().replace("\n", "")
|
||||
json_data = json.loads(json_str)
|
||||
return json_data
|
||||
except Exception as err: # noqa: F841
|
||||
return None
|
||||
|
||||
|
||||
def fullwidth_to_halfwidth(text):
|
||||
# 全角转半角
|
||||
result = ''
|
||||
for char in text:
|
||||
code_point = ord(char)
|
||||
# 全角空格直接转化
|
||||
if code_point == 0x3000:
|
||||
code_point = 0x0020
|
||||
# 其他全角字符(除空格)转换为半角
|
||||
elif 0xFF01 <= code_point <= 0xFF5E:
|
||||
code_point -= 0xFEE0
|
||||
result += chr(code_point)
|
||||
result = result.replace("、", ",")
|
||||
return result
|
||||
|
||||
|
||||
def remove_unnecessary_spaces(text):
|
||||
# 去掉中文字符之间的空格
|
||||
text = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])', '', text)
|
||||
# 去掉中文和英文、数字之间的空格
|
||||
text = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[a-zA-Z0-9])', '', text)
|
||||
text = re.sub(r'(?<=[a-zA-Z0-9])\s+(?=[\u4e00-\u9fff])', '', text)
|
||||
# 去掉符号前的不必要空格,保留符号后的一个空格
|
||||
text = re.sub(r'(?<![0-9])\s*([,.!?:;])\s*', r'\1 ', text) # 非数字前后的符号
|
||||
# 在数字和英文之间添加空格
|
||||
text = re.sub(r'(?<=[0-9])(?=[a-zA-Z])', ' ', text)
|
||||
text = re.sub(r'(?<=[a-zA-Z])(?=[0-9])', ' ', text)
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
return text
|
||||
|
||||
|
||||
class KieEvaluator(BaseMetric):
|
||||
def response_post_func(self, response_text, **kwargs):
|
||||
response_text = post_process_to_json(response_text, file_name=kwargs.get('file_name', None))
|
||||
return response_text
|
||||
|
||||
def normalize_func(self, text, **kwargs):
|
||||
halfwidth_text = fullwidth_to_halfwidth(str(text))
|
||||
cleaned_text = remove_unnecessary_spaces(halfwidth_text)
|
||||
return cleaned_text
|
||||
|
||||
def evaluate(self, response_info, gt_info, **kwargs):
|
||||
"""
|
||||
response_info: dict: {"file_name_1": response, "file_name_2": gt}
|
||||
gt_info: dict: {"file_name_1": gt, "file_name_2": gt}
|
||||
kwargs: dataset index config: {'dataset': 'kie_benchmark_POIE', 'group': 'kie', 'op': 'poie', 'num': 250}
|
||||
"""
|
||||
# gt should be a dict for kie task, fix for VLMEvalKit
|
||||
for image_name, label_content in gt_info.items():
|
||||
if isinstance(label_content, str):
|
||||
gt_info[image_name] = json.loads(label_content)
|
||||
|
||||
response_info = normalize_values_of_nested_dict(response_info, self.normalize_func)
|
||||
gt_info = normalize_values_of_nested_dict(gt_info, self.normalize_func)
|
||||
|
||||
f1_score, class_eval_info, error_info = cal_f1_all(response_info, gt_info)
|
||||
acc_average, acc_error_info = cal_acc_all(response_info, gt_info)
|
||||
|
||||
# summary info
|
||||
summary_info = {"f1_score": f1_score, "acc": acc_average}
|
||||
eval_info = {"summary": summary_info, "class_f1_score": class_eval_info,
|
||||
"f1_error_info": error_info, "acc_error_info": acc_error_info}
|
||||
return eval_info
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
@@ -0,0 +1,106 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
from collections import Counter
|
||||
|
||||
# local import
|
||||
from .common import BaseMetric
|
||||
|
||||
|
||||
def token_normalize(token_text, is_lower=False, is_alphanum_only=False):
|
||||
"""
|
||||
"""
|
||||
if is_lower:
|
||||
token_text = token_text.lower()
|
||||
if is_alphanum_only:
|
||||
token_text = re.sub('[^A-Za-z0-9]+', '', token_text)
|
||||
return token_text
|
||||
|
||||
|
||||
def text_normalize_and_tokenize(text, is_keep_blank=True, is_lower=True, is_alphanum_only=False):
|
||||
text = text.replace("\t", " ").replace("\n", " ").replace("###", "").replace("***", "")
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
if not is_keep_blank:
|
||||
text = text.replace(" ", "")
|
||||
text_tokens = text.split(" ") if is_keep_blank else list(text)
|
||||
text_token_normalized = [token_normalize(t, is_lower, is_alphanum_only) for t in text_tokens]
|
||||
text_token_normalized = [x for x in text_token_normalized if len(x) > 0]
|
||||
return text_token_normalized
|
||||
|
||||
|
||||
def evaluate_single_sample(gts, preds):
|
||||
right_num = 0
|
||||
gt_counter_info = dict(Counter(gts))
|
||||
pdt_counter_info = dict(Counter(preds))
|
||||
for gt_token, gt_count in gt_counter_info.items():
|
||||
pred_count = pdt_counter_info.get(gt_token, 0)
|
||||
right_num += min(gt_count, pred_count)
|
||||
return right_num
|
||||
|
||||
|
||||
def calculate_metrics(response_info, gt_info, is_verbose=False):
|
||||
"""
|
||||
"""
|
||||
macro_recall_list, macro_precision_list, macro_f1_list = [], [], []
|
||||
total_gt_num, total_pred_num, total_right_num = 0, 0, 0
|
||||
for file_name, fullbox_gts in gt_info.items():
|
||||
fullbox_preds = response_info.get(file_name, [])
|
||||
right_num = evaluate_single_sample(fullbox_gts, fullbox_preds)
|
||||
total_right_num += right_num
|
||||
total_gt_num += len(fullbox_gts)
|
||||
total_pred_num += len(fullbox_preds)
|
||||
|
||||
macro_recall = right_num / (len(fullbox_gts) + 1e-9)
|
||||
macro_precision = right_num / (len(fullbox_preds) + 1e-9)
|
||||
macro_f1 = 2 * macro_recall * macro_precision / (macro_recall + macro_precision + 1e-9)
|
||||
macro_recall_list.append(macro_recall)
|
||||
macro_precision_list.append(macro_precision)
|
||||
macro_f1_list.append(macro_f1)
|
||||
|
||||
# marco
|
||||
final_macro_recall = sum(macro_recall_list) / (len(macro_recall_list) + 1e-9)
|
||||
final_macro_precision = sum(macro_precision_list) / (len(macro_precision_list) + 1e-9)
|
||||
final_macro_f1 = sum(macro_f1_list) / (len(macro_f1_list) + 1e-9)
|
||||
|
||||
# micro
|
||||
recall_acc = total_right_num / (total_gt_num + 1e-9)
|
||||
preci_acc = total_right_num / (total_pred_num + 1e-9)
|
||||
hmean = 2 * recall_acc * preci_acc / (recall_acc + preci_acc + 1e-9)
|
||||
vbs_eval_result = {
|
||||
'macro_recall': final_macro_recall, 'macro_precision': final_macro_precision, 'macro_f1_score': final_macro_f1,
|
||||
'micro_recall': recall_acc, 'micro_precision': preci_acc, 'mirco_f1_score': hmean
|
||||
}
|
||||
eval_result = vbs_eval_result if is_verbose else {'macro_f1_score': final_macro_f1, 'mirco_f1_score': hmean}
|
||||
return eval_result
|
||||
|
||||
|
||||
class OcrEvaluator(BaseMetric):
|
||||
def response_post_func(self, response_text, **kwargs):
|
||||
return response_text
|
||||
|
||||
def evaluate(self, response_info, gt_info, **kwargs):
|
||||
# hard code here
|
||||
dataset_name = kwargs['dataset']
|
||||
is_word_level, is_lower, is_alphanum_only = True, True, False
|
||||
if dataset_name in ["Arabic", "Japanese", "Korean"] or "zh" in dataset_name:
|
||||
is_word_level = False
|
||||
if "multi_scene_ocr" in self.group_name and is_word_level:
|
||||
is_alphanum_only = True
|
||||
eval_config = {"word_level": is_word_level, "alphanum_only": is_alphanum_only, "lowercase": is_lower}
|
||||
|
||||
image_pdt_info, image_gt_info = {}, {}
|
||||
for file_name, gt_src in gt_info.items():
|
||||
pred_src = response_info.get(file_name, "")
|
||||
pdt_token_list = text_normalize_and_tokenize(
|
||||
str(pred_src).strip(), is_word_level, is_lower, is_alphanum_only)
|
||||
gt_token_list = text_normalize_and_tokenize(
|
||||
str(gt_src).strip(), is_word_level, is_lower, is_alphanum_only)
|
||||
image_pdt_info[file_name] = pdt_token_list
|
||||
image_gt_info[file_name] = gt_token_list
|
||||
eval_result = calculate_metrics(image_pdt_info, image_gt_info, is_verbose=False)
|
||||
return {"summary": eval_result, "metric_config": eval_config}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
Reference in New Issue
Block a user