Compare commits
123 Commits
| SHA1 |
|---|
| fdfa37e768 |
| 587306e3a6 |
| 62327ab934 |
| 8bfe84fa7e |
| f27b88951c |
| 784c46ba45 |
| 8b1edaeffd |
| d364570bff |
| 47283856a3 |
| 06be4aa3d2 |
| 166d987e48 |
| 65932999c4 |
| 94aad96afd |
| 7770da6253 |
| e3415db52b |
| 4d2de6f90f |
| ae0df6cbb5 |
| 56517f7a27 |
| 1850a6b825 |
| 9cf46a3a2a |
| 3ed9416cd9 |
| 1a3f00479f |
| 8df2c5e786 |
| a946b052c0 |
| d8f382e157 |
| ec68cefc17 |
| 3ad07c583a |
| 0b9fa42dd6 |
| f8f4998a49 |
| 332ed62161 |
| e29c918333 |
| 639cfc9412 |
| 03d2412085 |
| 4a011f46d1 |
| 2421a033fb |
| 2c093c2ab3 |
| ea1b4acc86 |
| c4117b72a6 |
| 4814f78a98 |
| 62c5e88a17 |
| a84fd37adc |
| 10bee02ce5 |
| 5520d436ed |
| 2c0e7ff051 |
| 880095e28c |
| cddeb03396 |
| 58f7328e7a |
| fadf8c398a |
| 25e0d84d2f |
| fd8674cc72 |
| cd0972c7a1 |
| f2d4d4b130 |
| 434e3874d3 |
| 8c19a59cf8 |
| 28f1ed2925 |
| e81bf19555 |
| 613803f6f1 |
| e1e04af112 |
| 40a54bb0e3 |
| b178622f73 |
| 7d160d7aeb |
| 4977fec2ff |
| f09ffe355a |
| 3c2f729530 |
| f0d641b578 |
| ce0955c0f4 |
| 5cfd89090e |
| cbe7ade404 |
| 62034f183f |
| 2cece543fa |
| 1c51a220f0 |
| 516777e462 |
| 5e719efab0 |
| 495a810f87 |
| 806a1015d8 |
| 6ce232a06c |
| b78137435a |
| 85a21c8dc7 |
| 088622f7be |
| 07afc8e39a |
| 53c0174797 |
| b75a362dd6 |
| 4a087a8aec |
| c5e82b1bc7 |
| 8464c94a7b |
| 2ab9fa7913 |
| 96c9e25287 |
| 8ff6cc0ed0 |
| a209258d85 |
| 7bcca75e29 |
| fd938af276 |
| 7ec8b3eca4 |
| 0cda63b309 |
| f48f790d69 |
| c541f1044e |
| e56f2373f2 |
| 38c5495e1e |
| fa25b3f20f |
| 59224808a1 |
| b109c67478 |
| 344ddc2cb1 |
| 3745c3316a |
| dc5f809253 |
| 48ed792ab8 |
| ab1141ee45 |
| 421c6d7838 |
| 512d5a8bb0 |
| 1c31c6aa78 |
| 93aad9f29f |
| c19087cd13 |
| 0fadd70c9e |
| 4b13c46dbb |
| c9f5cd4b00 |
| 259d54ed0a |
| 0e4ec319cf |
| b01d8e4adb |
| 970cea7d60 |
| 338892394f |
| 5553046db7 |
| 30b2446b0f |
| cd64150b51 |
| 825abf10e2 |
| ee458ad848 |
| File | Changes |
|---|---|
| README_en.md | 1693 |
| README_zh.md | 2009 |
| assets/MiniCPM-V.jpg | BIN, new file (163 KiB) |
| assets/MiniCPM-V27.jpg | BIN, new file (163 KiB) |
| assets/MiniCPM-o.png | BIN, new file (373 KiB) |
| assets/Minicpm-v 37.jpg | BIN, new file (159 KiB) |
| assets/discord.png | BIN, new file (272 B) |
| assets/input_examples/assistant_default_female_voice.wav | BIN |
| assets/input_examples/assistant_female_voice.wav | BIN |
| assets/input_examples/assistant_male_voice.wav | BIN |
| assets/input_examples/audio_understanding.mp3 | BIN |
| assets/input_examples/chi-english-1.wav | BIN |
| assets/input_examples/exciting-emotion.wav | BIN |
| assets/input_examples/fast-pace.wav | BIN |
| assets/input_examples/indian-accent.wav | BIN |
3  assets/logo.html  (new file)
@@ -0,0 +1,3 @@
<span style="color:#56A7DA; font-size: 10em; font-weight: bold;">
MiniCPM-<span>o</span>
</span>
| File | Changes |
|---|---|
| assets/minicpm-o-26-framework-v2.png | BIN, new file (1.1 MiB) |
| assets/minicpm-o-26-framework.png | BIN, new file (1023 KiB) |
| assets/minicpm-o-group.jpeg | BIN, new file (90 KiB) |
|  | BIN, modified (before 112 KiB, after 47 KiB) |
| assets/minicpm-v26.png | BIN, new file (47 KiB) |
| assets/minicpm-v27.png | BIN, new file (148 KiB) |
| assets/minicpmo2_6/2dot6_o_demo_video_img.png | BIN, new file (3.1 MiB) |
| assets/minicpmo2_6/minicpmo2_6_diagram_train_NN.png | BIN, new file (1.8 MiB) |
| assets/minicpmo2_6/minicpmo2_6_math_intersect.png | BIN, new file (785 KiB) |
| assets/minicpmo2_6/minicpmo2_6_multi-image_bike.png | BIN, new file (8.6 MiB) |
| assets/minicpmo2_6/show_demo.jpg | BIN, new file (100 KiB) |
| assets/minicpmv35-2.jpg | BIN, new file (307 KiB) |
| assets/o-2dot6-demo-video-preview.png | BIN, new file (2.6 MiB) |
| assets/radar.jpg | BIN, new file (842 KiB) |
| assets/ref_audios/default.wav | BIN |
| assets/ref_audios/female_example.wav | BIN |
| assets/ref_audios/male_example.wav | BIN |
| assets/ref_audios/video_default.wav | BIN |
| assets/wechat-QR.jpeg | BIN, new file (52 KiB) |
| assets/wechat.png | BIN, new file (245 B) |
23  docs/best_practice_summary.md  (new file)
@@ -0,0 +1,23 @@

# MiniCPM-V Best Practices

**MiniCPM-V** is a series of end-side multimodal LLMs (MLLMs) designed for vision-language understanding. The models take image, video and text as inputs and provide high-quality text output, aiming to achieve **strong performance and efficient deployment**. The most notable models in this series currently include MiniCPM-Llama3-V 2.5 and MiniCPM-V 2.6. The following sections provide detailed tutorials and guidance for each version of the MiniCPM-V models.

## MiniCPM-V 2.6

MiniCPM-V 2.6 is the latest and most capable model in the MiniCPM-V series. With a total of 8B parameters, the model **surpasses GPT-4V in single image, multi-image and video understanding**. It outperforms **GPT-4o mini, Gemini 1.5 Pro and Claude 3.5 Sonnet** in single image understanding, and advances MiniCPM-Llama3-V 2.5's features such as strong OCR capability, trustworthy behavior, multilingual support, and end-side deployment. Thanks to its superior token density, MiniCPM-V 2.6 is the first model to support real-time video understanding on end-side devices such as the iPad.

* [Deployment Tutorial](https://modelbest.feishu.cn/wiki/C2BWw4ZP0iCDy7kkCPCcX2BHnOf)
* [Training Tutorial](https://modelbest.feishu.cn/wiki/GeHMwLMa0i2FhUkV0f6cz3HWnV1)
* [Quantization Tutorial](https://modelbest.feishu.cn/wiki/YvsPwnPwWiqUjlkmW0scQ76TnBb)

## MiniCPM-Llama3-V 2.5

MiniCPM-Llama3-V 2.5 is built on SigLip-400M and Llama3-8B-Instruct with a total of 8B parameters. It exhibits a significant performance improvement over MiniCPM-V 2.0.

* [Quantization Tutorial](https://modelbest.feishu.cn/wiki/Kc7ywV4X1ipSaAkuPFOc9SFun8b)
* [Training Tutorial](https://modelbest.feishu.cn/wiki/UpSiw63o9iGDhIklmwScX4a6nhW)
* [End-side Deployment](https://modelbest.feishu.cn/wiki/Lwr9wpOQdinr6AkLzHrc9LlgnJD)
* [Deployment Tutorial](https://modelbest.feishu.cn/wiki/LTOKw3Hz7il9kGkCLX9czsennKe)
* [HD Decoding Tutorial](https://modelbest.feishu.cn/wiki/Ug8iwdXfhiHVsDk2gGEco6xnnVg)
* [Model Structure](https://modelbest.feishu.cn/wiki/ACtAw9bOgiBQ9lkWyafcvtVEnQf)
22  docs/best_practice_summary_zh.md  (new file)
@@ -0,0 +1,22 @@

# MiniCPM-V Best Practices

**MiniCPM-V** is a series of end-side multimodal LLMs for vision-language understanding. The models take images and text as input and provide high-quality text output. Since February 2024 we have released five model versions, aiming for **leading performance and efficient deployment**. The most notable models in the series currently include:

## MiniCPM-V 2.6

The latest and best-performing model in the MiniCPM-V series. With a total of 8B parameters, it **surpasses GPT-4V** in single-image, multi-image and video understanding. In single-image understanding it outperforms proprietary models such as **GPT-4o mini, Gemini 1.5 Pro and Claude 3.5 Sonnet**, and further improves the OCR capability, trustworthy behavior, multilingual support and end-side deployment of MiniCPM-Llama3-V 2.5. Thanks to its leading visual token density, MiniCPM-V 2.6 became the first multimodal LLM to support real-time video understanding on end-side devices such as the iPad.

* [Deployment Tutorial](https://modelbest.feishu.cn/wiki/LZxLwp4Lzi29vXklYLFchwN5nCf)
* [Training Tutorial](https://modelbest.feishu.cn/wiki/HvfLwYzlIihqzXkmeCdczs6onmd)
* [Quantization Tutorial](https://modelbest.feishu.cn/wiki/PAsHw6N6xiEy0DkJWpJcIocRnz9)

## MiniCPM-Llama3-V 2.5

MiniCPM-Llama3-V 2.5 is built on SigLip-400M and Llama3-8B-Instruct with a total of 8 billion parameters. Its performance improves significantly over MiniCPM-V 2.0.

* [Quantization Tutorial](https://modelbest.feishu.cn/wiki/O0KTwQV5piUPzTkRXl9cSFyHnQb)
* [Training Tutorial](https://modelbest.feishu.cn/wiki/MPkPwvONEiZm3BkWMnyc83Tin4d)
* [End-side Deployment](https://modelbest.feishu.cn/wiki/CZZJw1EDGitSSZka664cZwbWnrb)
* [Deployment Tutorial](https://modelbest.feishu.cn/wiki/BcHIwjOLGihJXCkkSdMc2WhbnZf)
* [HD Decoding Tutorial](https://modelbest.feishu.cn/wiki/L0ajwm8VAiiPY6kDZfJce3B7nRg)
* [Model Structure](https://modelbest.feishu.cn/wiki/X15nwGzqpioxlikbi2RcXDpJnjd)
445  docs/llamafactory_train_and_infer.md  (new file)
@@ -0,0 +1,445 @@

# Best Practice with LLaMA-Factory

## Contents <!-- omit in toc -->

- [Supported Models](#supported-models)
- [LLaMA-Factory Installation](#llama-factory-installation)
- [Dataset Preparation](#dataset-preparation)
  - [Image Dataset](#image-dataset)
  - [Video Dataset](#video-dataset)
  - [Audio Dataset](#audio-dataset)
- [LoRA Fine-Tuning](#lora-fine-tuning)
- [Full Parameters Fine-Tuning](#full-parameters-fine-tuning)
- [Inference](#inference)

## Supported Models

* [openbmb/MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6)
* [openbmb/MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6)

## LLaMA-Factory Installation

You can install LLaMA-Factory with the commands below.

```shell
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics,deepspeed,minicpm_v]"
mkdir configs # let's put all yaml files here
```
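If your LLaMA-Factory version provides it, a quick sanity check after installation is to print the CLI version (this is only a convenience check, not a required step):

```shell
llamafactory-cli version
```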
## Dataset Preparation

Refer to [data/dataset_info.json](https://github.com/hiyouga/LLaMA-Factory/blob/main/data/dataset_info.json) to register your customised dataset. Let's use the three existing demo datasets `mllm_demo`, `mllm_video_demo` and `mllm_audio_demo` as examples (audio is only supported by MiniCPM-o-2.6). A registration entry for a custom dataset might look like the sketch below.
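The following is only a sketch of a `dataset_info.json` entry for a custom image SFT dataset in the same ShareGPT-style layout as `mllm_demo`; the dataset name `my_mllm_data` and its file name are placeholders, and the exact keys should be double-checked against the entries already present in `data/dataset_info.json`:

```json
"my_mllm_data": {
  "file_name": "my_mllm_data.json",
  "formatting": "sharegpt",
  "columns": {
    "messages": "messages",
    "images": "images"
  },
  "tags": {
    "role_tag": "role",
    "content_tag": "content",
    "user_tag": "user",
    "assistant_tag": "assistant"
  }
}
```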
### Image Dataset

Refer to the image SFT demo data: [data/mllm_demo.json](https://github.com/hiyouga/LLaMA-Factory/blob/main/data/mllm_demo.json)

<details>
<summary>
<b>data/mllm_demo.json</b>
</summary>

```json
[
  {
    "messages": [
      {"content": "<image>Who are they?", "role": "user"},
      {"content": "They're Kane and Gretzka from Bayern Munich.", "role": "assistant"},
      {"content": "What are they doing?", "role": "user"},
      {"content": "They are celebrating on the soccer field.", "role": "assistant"}
    ],
    "images": ["mllm_demo_data/1.jpg"]
  },
  {
    "messages": [
      {"content": "<image>Who is he?", "role": "user"},
      {"content": "He's Thomas Muller from Bayern Munich.", "role": "assistant"},
      {"content": "Why is he on the ground?", "role": "user"},
      {"content": "Because he's sliding on his knees to celebrate.", "role": "assistant"}
    ],
    "images": ["mllm_demo_data/2.jpg"]
  },
  {
    "messages": [
      {"content": "<image>Please describe this image", "role": "user"},
      {"content": "Chinese astronaut Gui Haichao is giving a speech.", "role": "assistant"},
      {"content": "What has he accomplished?", "role": "user"},
      {"content": "He was appointed to be a payload specialist on Shenzhou 16 mission in June 2022, thus becoming the first Chinese civilian of Group 3 in space on 30 May 2023. He is responsible for the on-orbit operation of space science experimental payloads.", "role": "assistant"}
    ],
    "images": ["mllm_demo_data/3.jpg"]
  }
]
```

</details>
### Video Dataset

Refer to the video SFT demo data: [data/mllm_video_demo.json](https://github.com/hiyouga/LLaMA-Factory/blob/main/data/mllm_video_demo.json)

<details>
<summary>
<b>data/mllm_video_demo.json</b>
</summary>

```json
[
  {
    "messages": [
      {"content": "<video>Why is this video funny?", "role": "user"},
      {"content": "Because a baby is reading, and he is so cute!", "role": "assistant"}
    ],
    "videos": ["mllm_demo_data/1.mp4"]
  },
  {
    "messages": [
      {"content": "<video>What is she doing?", "role": "user"},
      {"content": "She is cooking.", "role": "assistant"}
    ],
    "videos": ["mllm_demo_data/2.avi"]
  },
  {
    "messages": [
      {"content": "<video>What's in the video?", "role": "user"},
      {"content": "A baby is playing in the living room.", "role": "assistant"}
    ],
    "videos": ["mllm_demo_data/3.mp4"]
  }
]
```

</details>
### Audio Dataset

Refer to the audio SFT demo data: [data/mllm_audio_demo.json](https://github.com/hiyouga/LLaMA-Factory/blob/main/data/mllm_audio_demo.json)

<details>
<summary>
<b>data/mllm_audio_demo.json</b>
</summary>

```json
[
  {
    "messages": [
      {"content": "<audio>What's that sound?", "role": "user"},
      {"content": "It is the sound of glass shattering.", "role": "assistant"}
    ],
    "audios": ["mllm_demo_data/1.mp3"]
  },
  {
    "messages": [
      {"content": "<audio>What can you hear?", "role": "user"},
      {"content": "A woman is coughing.", "role": "assistant"}
    ],
    "audios": ["mllm_demo_data/2.wav"]
  },
  {
    "messages": [
      {"content": "<audio>What does the person say?", "role": "user"},
      {"content": "Mister Quiller is the apostle of the middle classes and we are glad to welcome his gospel.", "role": "assistant"}
    ],
    "audios": ["mllm_demo_data/3.flac"]
  }
]
```

</details>
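Before launching a training run, it can help to sanity-check a custom dataset file. The snippet below is an optional helper (not part of LLaMA-Factory) that verifies each sample has as many `<image>`/`<video>`/`<audio>` placeholders in its user turns as there are entries in the corresponding media list; the file path is a placeholder you can point at your own data.

```python
import json

def check_mllm_dataset(path: str) -> None:
    """Basic structural checks for a ShareGPT-style multimodal SFT file."""
    with open(path, "r", encoding="utf-8") as f:
        samples = json.load(f)

    for i, sample in enumerate(samples):
        messages = sample["messages"]
        # Conversations are expected to alternate user/assistant turns, starting with a user turn.
        roles = [m["role"] for m in messages]
        assert roles[::2] == ["user"] * len(roles[::2]), f"sample {i}: unexpected role order"

        # Count media placeholders in the user turns and compare with the media lists.
        text = "".join(m["content"] for m in messages if m["role"] == "user")
        for tag, key in (("<image>", "images"), ("<video>", "videos"), ("<audio>", "audios")):
            n_tags = text.count(tag)
            n_files = len(sample.get(key, []))
            assert n_tags == n_files, f"sample {i}: {n_tags} {tag} tags vs {n_files} {key}"

    print(f"{path}: {len(samples)} samples look consistent")

check_mllm_dataset("data/mllm_demo.json")
```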
## LoRA Fine-Tuning

We can run LoRA SFT with a single command:

```shell
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train configs/minicpmo_2_6_lora_sft.yaml
```

<details>
<summary>
<b>configs/minicpmo_2_6_lora_sft.yaml</b>
</summary>

```yaml
### model
model_name_or_path: openbmb/MiniCPM-o-2_6 # MiniCPM-o-2_6 MiniCPM-V-2_6
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: q_proj,v_proj

### dataset
dataset: mllm_demo # mllm_demo mllm_video_demo mllm_audio_demo
template: minicpm_o # minicpm_o minicpm_v
cutoff_len: 3072
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/minicpmo_2_6/lora/sft
logging_steps: 1
save_steps: 100
plot_loss: true
overwrite_output_dir: true
save_total_limit: 10

### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
num_train_epochs: 20.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
save_only_model: true

### eval
do_eval: false
```

</details>
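To scale the same recipe to several GPUs, recent LLaMA-Factory versions can launch training through `torchrun`. The sketch below assumes the `FORCE_TORCHRUN` switch used in the LLaMA-Factory examples and simply reuses the config above on two GPUs:

```shell
# Distributed LoRA SFT on 2 GPUs (assumes FORCE_TORCHRUN is supported by your LLaMA-Factory version)
FORCE_TORCHRUN=1 CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train configs/minicpmo_2_6_lora_sft.yaml
```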
### LoRA Model Export

One command exports the merged LoRA model:

```shell
llamafactory-cli export configs/minicpmo_2_6_lora_export.yaml
```

<details>
<summary>
<b>configs/minicpmo_2_6_lora_export.yaml</b>
</summary>

```yaml
### model
model_name_or_path: openbmb/MiniCPM-o-2_6 # MiniCPM-o-2_6 MiniCPM-V-2_6
adapter_name_or_path: saves/minicpmo_2_6/lora/sft
template: minicpm_o # minicpm_o minicpm_v
finetuning_type: lora
trust_remote_code: true

### export
export_dir: models/minicpmo_2_6_lora_sft
export_size: 2
export_device: cpu
export_legacy_format: false
```

</details>
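If you prefer to chat with the adapter directly instead of exporting a merged checkpoint, LLaMA-Factory inference configs can also point at the base model plus the LoRA weights. A sketch (the file name is a placeholder; the keys mirror the configs above and the LoRA inference examples shipped with LLaMA-Factory):

```yaml
# configs/minicpmo_2_6_lora_infer.yaml (hypothetical file name)
model_name_or_path: openbmb/MiniCPM-o-2_6
adapter_name_or_path: saves/minicpmo_2_6/lora/sft
template: minicpm_o
finetuning_type: lora
trust_remote_code: true
```

Such a config can then be passed to `llamafactory-cli chat` or `llamafactory-cli webchat` in the same way as the config shown in the Inference section below.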
## Full Parameters Fine-Tuning

We can run full-parameter SFT with a single command:

```shell
llamafactory-cli train configs/minicpmo_2_6_full_sft.yaml
```

<details>
<summary>
<b>configs/minicpmo_2_6_full_sft.yaml</b>
</summary>

```yaml
### model
model_name_or_path: openbmb/MiniCPM-o-2_6 # MiniCPM-o-2_6 MiniCPM-V-2_6
trust_remote_code: true
freeze_vision_tower: true
print_param_status: true
flash_attn: fa2

### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: configs/deepspeed/ds_z2_config.json

### dataset
dataset: mllm_demo # mllm_demo mllm_video_demo
template: minicpm_o # minicpm_o minicpm_v
cutoff_len: 3072
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/minicpmo_2_6/full/sft
logging_steps: 1
save_steps: 100
plot_loss: true
overwrite_output_dir: true
save_total_limit: 10

### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
num_train_epochs: 20.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
save_only_model: true

### eval
do_eval: false
```

</details>
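The config above references `configs/deepspeed/ds_z2_config.json`, which is not shown here. A minimal ZeRO stage-2 sketch that could be placed at that path is given below; LLaMA-Factory also ships reference DeepSpeed configs in its examples directory that you can copy instead, and the `auto` values let the trainer fill in batch sizes and precision:

```json
{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "zero_allow_untested_optimizer": true,
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "overlap_comm": true,
    "reduce_scatter": true,
    "contiguous_gradients": true
  }
}
```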
## Inference

### Web UI ChatBox

Refer to the [LLaMA-Factory doc](https://github.com/hiyouga/LLaMA-Factory/tree/main/examples#inferring-lora-fine-tuned-models) for more inference usages.

For example, we can launch a web chat with one command:

```shell
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat configs/minicpmo_2_6_infer.yaml
```

<details>
<summary>
<b>configs/minicpmo_2_6_infer.yaml</b>
</summary>

```yaml
model_name_or_path: saves/minicpmo_2_6/full/sft
template: minicpm_o # minicpm_o minicpm_v
infer_backend: huggingface
trust_remote_code: true
```

</details>
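Besides the Gradio web UI, LLaMA-Factory can also serve a fine-tuned model behind an OpenAI-compatible API. The sketch below reuses the same inference config and assumes the `llamafactory-cli api` entry point and the `API_PORT` variable described in the LLaMA-Factory examples:

```shell
# Launch an OpenAI-compatible server on port 8000 (API_PORT is assumed to be honored by your version)
API_PORT=8000 CUDA_VISIBLE_DEVICES=0 llamafactory-cli api configs/minicpmo_2_6_infer.yaml
```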
### Official Code

You can also run inference with the official MiniCPM-o code:

<details>
<summary>
<b>official inference code</b>
</summary>

```python
# test.py
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

model_id = "saves/minicpmo_2_6/full/sft"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True,
    attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
model = model.eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

image = Image.open('data/mllm_demo_data/1.jpg').convert('RGB')
question = 'Who are they?'
msgs = [{'role': 'user', 'content': [image, question]}]

res = model.chat(
    image=None,
    msgs=msgs,
    tokenizer=tokenizer
)
print(res)
```

</details>
333  docs/minicpm_llama3_v2dot5.md  (new file)
@@ -0,0 +1,333 @@

## MiniCPM-Llama3-V 2.5

> Archived at: 2025-01-13

**MiniCPM-Llama3-V 2.5** is the latest model in the MiniCPM-V series. The model is built on SigLip-400M and Llama3-8B-Instruct with a total of 8B parameters. It exhibits a significant performance improvement over MiniCPM-V 2.0. Notable features of MiniCPM-Llama3-V 2.5 include:

- 🔥 **Leading Performance.**
  MiniCPM-Llama3-V 2.5 has achieved an average score of 65.1 on OpenCompass, a comprehensive evaluation over 11 popular benchmarks. **With only 8B parameters, it surpasses widely used proprietary models like GPT-4V-1106, Gemini Pro, Claude 3 and Qwen-VL-Max** and greatly outperforms other Llama 3-based MLLMs.

- 💪 **Strong OCR Capabilities.**
  MiniCPM-Llama3-V 2.5 can process images with any aspect ratio and up to 1.8 million pixels (e.g., 1344x1344), achieving a **700+ score on OCRBench, surpassing proprietary models such as GPT-4o, GPT-4V-0409, Qwen-VL-Max and Gemini Pro**. Based on recent user feedback, MiniCPM-Llama3-V 2.5 has now enhanced full-text OCR extraction, table-to-markdown conversion, and other high-utility capabilities, and has further strengthened its instruction-following and complex reasoning abilities, enhancing multimodal interaction experiences.

- 🏆 **Trustworthy Behavior.**
  Leveraging the latest [RLAIF-V](https://github.com/RLHF-V/RLAIF-V/) method (the newest technique in the [RLHF-V](https://github.com/RLHF-V) [CVPR'24] series), MiniCPM-Llama3-V 2.5 exhibits more trustworthy behavior. It achieves a **10.3%** hallucination rate on Object HalBench, lower than GPT-4V-1106 (13.6%), achieving the best-level performance within the open-source community. [Data released](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset).

- 🌏 **Multilingual Support.**
  Thanks to the strong multilingual capabilities of Llama 3 and the cross-lingual generalization technique from [VisCPM](https://github.com/OpenBMB/VisCPM), MiniCPM-Llama3-V 2.5 extends its bilingual (Chinese-English) multimodal capabilities to **over 30 languages including German, French, Spanish, Italian, Korean, etc.** [All Supported Languages](./assets/minicpm-llama-v-2-5_languages.md).

- 🚀 **Efficient Deployment.**
  MiniCPM-Llama3-V 2.5 systematically employs **model quantization, CPU optimizations, NPU optimizations and compilation optimizations**, achieving high-efficiency deployment on end-side devices. For mobile phones with Qualcomm chips, we have integrated the NPU acceleration framework QNN into llama.cpp for the first time. After systematic optimization, MiniCPM-Llama3-V 2.5 has realized a **150x acceleration in end-side MLLM image encoding** and a **3x speedup in language decoding**.

- 💫 **Easy Usage.**
  MiniCPM-Llama3-V 2.5 can be easily used in various ways: (1) [llama.cpp](https://github.com/OpenBMB/llama.cpp/blob/minicpm-v2.5/examples/minicpmv/README.md) and [ollama](https://github.com/OpenBMB/ollama/tree/minicpm-v2.5/examples/minicpm-v2.5) support for efficient CPU inference on local devices, (2) [GGUF](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf) format quantized models in 16 sizes, (3) efficient [LoRA](https://github.com/OpenBMB/MiniCPM-V/tree/main/finetune#lora-finetuning) fine-tuning with only 2 V100 GPUs, (4) [streaming output](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5#usage), (5) quick local WebUI demo setup with [Gradio](https://github.com/OpenBMB/MiniCPM-V/blob/main/web_demo_2.5.py) and [Streamlit](https://github.com/OpenBMB/MiniCPM-V/blob/main/web_demo_streamlit-2_5.py), and (6) interactive demos on [HuggingFace Spaces](https://huggingface.co/spaces/openbmb/MiniCPM-Llama3-V-2_5).
### Evaluation <!-- omit in toc -->

<div align="center">
<img src=../assets/MiniCPM-Llama3-V-2.5-peformance.png width=66% />
</div>

<details>
<summary>Click to view results on TextVQA, DocVQA, OCRBench, OpenCompass, MME, MMBench, MMMU, MathVista, LLaVA Bench, RealWorld QA, Object HalBench. </summary>

| Model | Size | OCRBench | TextVQA val | DocVQA test | OpenCompass | MME | MMB test (en) | MMB test (cn) | MMMU val | MathVista | LLaVA Bench | RealWorld QA | Object HalBench |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| **Proprietary** | | | | | | | | | | | | | |
| Gemini Pro | - | 680 | 74.6 | 88.1 | 62.9 | 2148.9 | 73.6 | 74.3 | 48.9 | 45.8 | 79.9 | 60.4 | - |
| GPT-4V (2023.11.06) | - | 645 | 78.0 | 88.4 | 63.5 | 1771.5 | 77.0 | 74.4 | 53.8 | 47.8 | 93.1 | 63.0 | 86.4 |
| **Open-source** | | | | | | | | | | | | | |
| Mini-Gemini | 2.2B | - | 56.2 | 34.2* | - | 1653.0 | - | - | 31.7 | - | - | - | - |
| Qwen-VL-Chat | 9.6B | 488 | 61.5 | 62.6 | 51.6 | 1860.0 | 61.8 | 56.3 | 37.0 | 33.8 | 67.7 | 49.3 | 56.2 |
| DeepSeek-VL-7B | 7.3B | 435 | 64.7* | 47.0* | 54.6 | 1765.4 | 73.8 | 71.4 | 38.3 | 36.8 | 77.8 | 54.2 | - |
| Yi-VL-34B | 34B | 290 | 43.4* | 16.9* | 52.2 | **2050.2** | 72.4 | 70.7 | 45.1 | 30.7 | 62.3 | 54.8 | 79.3 |
| CogVLM-Chat | 17.4B | 590 | 70.4 | 33.3* | 54.2 | 1736.6 | 65.8 | 55.9 | 37.3 | 34.7 | 73.9 | 60.3 | 73.6 |
| TextMonkey | 9.7B | 558 | 64.3 | 66.7 | - | - | - | - | - | - | - | - | - |
| Idefics2 | 8.0B | - | 73.0 | 74.0 | 57.2 | 1847.6 | 75.7 | 68.6 | 45.2 | 52.2 | 49.1 | 60.7 | - |
| Bunny-LLama-3-8B | 8.4B | - | - | - | 54.3 | 1920.3 | 77.0 | 73.9 | 41.3 | 31.5 | 61.2 | 58.8 | - |
| LLaVA-NeXT Llama-3-8B | 8.4B | - | - | 78.2 | - | 1971.5 | - | - | 41.7 | 37.5 | 80.1 | 60.0 | - |
| Phi-3-vision-128k-instruct | 4.2B | 639* | 70.9 | - | - | 1537.5* | - | - | 40.4 | 44.5 | 64.2* | 58.8* | - |
| MiniCPM-V 1.0 | 2.8B | 366 | 60.6 | 38.2 | 47.5 | 1650.2 | 64.1 | 62.6 | 38.3 | 28.9 | 51.3 | 51.2 | 78.4 |
| MiniCPM-V 2.0 | 2.8B | 605 | 74.1 | 71.9 | 54.5 | 1808.6 | 69.1 | 66.5 | 38.2 | 38.7 | 69.2 | 55.8 | 85.5 |
| MiniCPM-Llama3-V 2.5 | 8.5B | **725** | **76.6** | **84.8** | **65.1** | 2024.6 | **77.2** | **74.2** | **45.8** | **54.3** | **86.7** | **63.5** | **89.7** |

* We evaluate the officially released checkpoint by ourselves.

</details>
<div align="center">
<img src="../assets/llavabench_compare_3.png" width="100%" />
<br>
Evaluation results of multilingual LLaVA Bench
</div>

### Examples <!-- omit in toc -->

<table align="center" >
<p align="center" >
<img src="../assets/minicpmv-llama3-v2.5/cases_all.png" />
</p>
</table>

### Model Zoo

| Model | Device | Memory | Description | Download |
|:-----------|:--:|:-----------:|:-------------------|:---------------:|
| MiniCPM-Llama3-V 2.5 | GPU | 19 GB | Strong end-side multimodal performance. | [🤗](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5/) [<img src="../assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-Llama3-V-2_5) |
| MiniCPM-Llama3-V 2.5 gguf | CPU | 6 GB | The gguf version, lower memory usage and faster inference. | [🤗](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf) [<img src="../assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-Llama3-V-2_5-gguf) |
| MiniCPM-Llama3-V 2.5 int4 | GPU | 8 GB | The int4 quantized version, lower GPU memory usage. | [🤗](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-int4/) [<img src="../assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-Llama3-V-2_5-int4) |
299  docs/minicpm_v2.md  (new file)
@@ -0,0 +1,299 @@

## MiniCPM-V 2.0

> Archived at: 2025-01-13

**MiniCPM-V 2.0** is an efficient version with promising performance for deployment. The model is built on SigLip-400M and [MiniCPM-2.4B](https://github.com/OpenBMB/MiniCPM/), connected by a perceiver resampler. Our latest version, MiniCPM-V 2.0, has several notable features.

- 🔥 **State-of-the-art Performance.**

  MiniCPM-V 2.0 achieves **state-of-the-art performance** on multiple benchmarks (including OCRBench, TextVQA, MME, MMB, MathVista, etc.) among models under 7B parameters. It even **outperforms the strong Qwen-VL-Chat 9.6B, CogVLM-Chat 17.4B, and Yi-VL 34B on OpenCompass, a comprehensive evaluation over 11 popular benchmarks**. Notably, MiniCPM-V 2.0 shows **strong OCR capability**, achieving **comparable performance to Gemini Pro in scene-text understanding**, and **state-of-the-art performance on OCRBench** among open-source models.

- 🏆 **Trustworthy Behavior.**

  LMMs are known to suffer from hallucination, often generating text that is not factually grounded in images. MiniCPM-V 2.0 is **the first end-side LMM aligned via multimodal RLHF for trustworthy behavior** (using the recent [RLHF-V](https://rlhf-v.github.io/) [CVPR'24] series technique). This allows the model to **match GPT-4V in preventing hallucinations** on Object HalBench.

- 🌟 **High-Resolution Images at Any Aspect Ratio.**

  MiniCPM-V 2.0 can accept **1.8 million pixel (e.g., 1344x1344) images at any aspect ratio**. This enables better perception of fine-grained visual information such as small objects and optical characters, which is achieved via a recent technique from [LLaVA-UHD](https://arxiv.org/pdf/2403.11703.pdf).

- ⚡️ **High Efficiency.**

  MiniCPM-V 2.0 can be **efficiently deployed on most GPU cards and personal computers**, and **even on end devices such as mobile phones**. For visual encoding, we compress the image representations into far fewer tokens via a perceiver resampler. This allows MiniCPM-V 2.0 to operate with **favorable memory cost and speed during inference even when dealing with high-resolution images**.

- 🙌 **Bilingual Support.**

  MiniCPM-V 2.0 **supports strong bilingual multimodal capabilities in both English and Chinese**. This is enabled by generalizing multimodal capabilities across languages, a technique from [VisCPM](https://arxiv.org/abs/2308.12038) [ICLR'24].
### Evaluation <!-- omit in toc -->

<div align="center">
<img src=../assets/minicpmv-2-peformance.png width=66% />
</div>

<details>
<summary>Click to view results on TextVQA, DocVQA, OCRBench, OpenCompass, MME, MMBench, MMMU, MathVista, LLaVA Bench, Object HalBench. </summary>

| Model | Size | TextVQA val | DocVQA test | OCRBench | OpenCompass | MME | MMB dev (en) | MMB dev (zh) | MMMU val | MathVista | LLaVA Bench | Object HalBench |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| **Proprietary models** | | | | | | | | | | | | |
| Gemini Pro Vision | - | 74.6 | 88.1 | 680 | 63.8 | 2148.9 | 75.2 | 74.0 | 48.9 | 45.8 | 79.9 | - |
| GPT-4V | - | 78.0 | 88.4 | 645 | 63.2 | 1771.5 | 75.1 | 75.0 | 53.8 | 47.8 | 93.1 | 86.4 / 92.7 |
| **Open-source models 6B~34B** | | | | | | | | | | | | |
| Yi-VL-6B | 6.7B | 45.5* | 17.1* | 290 | 49.3 | 1915.1 | 68.6 | 68.3 | 40.3 | 28.8 | 51.9 | - |
| Qwen-VL-Chat | 9.6B | 61.5 | 62.6 | 488 | 52.1 | 1860.0 | 60.6 | 56.7 | 37.0 | 33.8 | 67.7 | 56.2 / 80.0 |
| Yi-VL-34B | 34B | 43.4* | 16.9* | 290 | 52.6 | 2050.2 | 71.1 | 71.4 | 45.1 | 30.7 | 62.3 | - |
| DeepSeek-VL-7B | 7.3B | 64.7* | 47.0* | 435 | 55.6 | 1765.4 | 74.1 | 72.8 | 38.3 | 36.8 | 77.8 | - |
| TextMonkey | 9.7B | 64.3 | 66.7 | 558 | - | - | - | - | - | - | - | - |
| CogVLM-Chat | 17.4B | 70.4 | 33.3* | 590 | 52.5 | 1736.6 | 63.7 | 53.8 | 37.3 | 34.7 | 73.9 | 73.6 / 87.4 |
| **Open-source models 1B~3B** | | | | | | | | | | | | |
| DeepSeek-VL-1.3B | 1.7B | 58.4* | 37.9* | 413 | 46.0 | 1531.6 | 64.0 | 61.2 | 33.8 | 29.4 | 51.1 | - |
| MobileVLM V2 | 3.1B | 57.5 | 19.4* | - | - | 1440.5(P) | 63.2 | - | - | - | - | - |
| Mini-Gemini | 2.2B | 56.2 | 34.2* | - | - | 1653.0 | 59.8 | - | 31.7 | - | - | - |
| MiniCPM-V | 2.8B | 60.6 | 38.2 | 366 | 47.6 | 1650.2 | 67.9 | 65.3 | **38.3** | 28.9 | 51.3 | 78.4 / 88.5 |
| **MiniCPM-V 2.0** | 2.8B | **74.1** | **71.9** | **605** | **55.0** | **1808.6** | **69.6** | **68.1** | 38.2 | **38.7** | **69.2** | **85.5 / 92.2** |

* We evaluate the officially released checkpoint by ourselves.

</details>
### Examples <!-- omit in toc -->

<table align="center">
<p align="center">
<img src="../assets/minicpmv2-cases_2.png" width=95%/>
</p>
</table>

We deploy MiniCPM-V 2.0 on end devices. The demo video is a raw screen recording on a Xiaomi 14 Pro, without editing.

<table align="center">
<p align="center">
<img src="../assets/gif_cases/station.gif" width=36%/>
<img src="../assets/gif_cases/london_car.gif" width=36%/>
</p>
</table>

### Model Zoo

| Model | Device | Memory | Description | Download |
|:-----------|:--:|:-----------:|:-------------------|:---------------:|
| MiniCPM-V 2.0 | GPU | 8 GB | Light version, balancing performance and computation cost. | [🤗](https://huggingface.co/openbmb/MiniCPM-V-2) [<img src="../assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2) |
| MiniCPM-V 1.0 | GPU | 7 GB | Lightest version, achieving the fastest inference. | [🤗](https://huggingface.co/openbmb/MiniCPM-V) [<img src="../assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-V) |

### Deployment on Mobile Phone

MiniCPM-V 2.0 can be deployed on mobile phones with Android operating systems. 🚀 Click [MiniCPM-V 2.0](https://github.com/OpenBMB/mlc-MiniCPM) to install the APK.
945  docs/minicpm_v2dot6.md  (new file)
@@ -0,0 +1,945 @@

## MiniCPM-V 2.6

> Archived at: 2025-01-13

**MiniCPM-V 2.6** is the latest and most capable model in the MiniCPM-V series. The model is built on SigLip-400M and Qwen2-7B with a total of 8B parameters. It exhibits a significant performance improvement over MiniCPM-Llama3-V 2.5, and introduces new features for multi-image and video understanding. Notable features of MiniCPM-V 2.6 include:

- 🔥 **Leading Performance.**
  MiniCPM-V 2.6 achieves an average score of 65.2 on the latest version of OpenCompass, a comprehensive evaluation over 8 popular benchmarks. **With only 8B parameters, it surpasses widely used proprietary models like GPT-4o mini, GPT-4V, Gemini 1.5 Pro, and Claude 3.5 Sonnet** for single image understanding.

- 🖼️ **Multi Image Understanding and In-context Learning.** MiniCPM-V 2.6 can also perform **conversation and reasoning over multiple images**. It achieves **state-of-the-art performance** on popular multi-image benchmarks such as Mantis-Eval, BLINK, Mathverse mv and Sciverse mv, and also shows promising in-context learning capability.

- 🎬 **Video Understanding.** MiniCPM-V 2.6 can also **accept video inputs**, performing conversation and providing dense captions for spatial-temporal information. It outperforms **GPT-4V, Claude 3.5 Sonnet and LLaVA-NeXT-Video-34B** on Video-MME with/without subtitles.

- 💪 **Strong OCR Capability and Others.**
  MiniCPM-V 2.6 can process images with any aspect ratio and up to 1.8 million pixels (e.g., 1344x1344). It achieves **state-of-the-art performance on OCRBench, surpassing proprietary models such as GPT-4o, GPT-4V, and Gemini 1.5 Pro**.
  Based on the latest [RLAIF-V](https://github.com/RLHF-V/RLAIF-V/) and [VisCPM](https://github.com/OpenBMB/VisCPM) techniques, it features **trustworthy behaviors**, with significantly lower hallucination rates than GPT-4o and GPT-4V on Object HalBench, and supports **multilingual capabilities** in English, Chinese, German, French, Italian, Korean, etc.

- 🚀 **Superior Efficiency.**
  In addition to its friendly size, MiniCPM-V 2.6 also shows **state-of-the-art token density** (i.e., the number of pixels encoded into each visual token). **It produces only 640 tokens when processing a 1.8M pixel image, which is 75% fewer than most models**. This directly improves the inference speed, first-token latency, memory usage, and power consumption. As a result, MiniCPM-V 2.6 can efficiently support **real-time video understanding** on end-side devices such as iPad.

- 💫 **Easy Usage.**
  MiniCPM-V 2.6 can be easily used in various ways: (1) [llama.cpp](https://github.com/OpenBMB/llama.cpp/blob/minicpmv-main/examples/llava/README-minicpmv2.6.md) and [ollama](https://github.com/OpenBMB/ollama/blob/minicpm-v2.6/examples/minicpm-v2.6/README.md) support for efficient CPU inference on local devices (see the quick-start sketch after this list), (2) [int4](https://huggingface.co/openbmb/MiniCPM-V-2_6-int4) and [GGUF](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) format quantized models in 16 sizes, (3) [vLLM](#inference-with-vllm) support for high-throughput and memory-efficient inference, (4) fine-tuning on new domains and tasks, (5) quick local WebUI demo setup with [Gradio](#chat-with-our-demo-on-gradio), and (6) online web [demo](http://120.92.209.146:8887/).
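For local CPU inference, a minimal quick start through ollama might look like the following; it assumes the `minicpm-v` model published in the Ollama model library and an image path on your machine, so adjust both to your setup:

```shell
# Pull and chat with MiniCPM-V 2.6 through ollama (model name assumed from the Ollama library)
ollama pull minicpm-v
ollama run minicpm-v "Describe this image: ./assets/minicpmo2_6/show_demo.jpg"
```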
### Evaluation <!-- omit in toc -->

<div align="center">
<img src=../assets/radar_final.png width=66% />
</div>

<details>
<summary>Click to view single image results on OpenCompass, MME, MMVet, OCRBench, MMMU, MathVista, MMB, AI2D, TextVQA, DocVQA, HallusionBench, Object HalBench. </summary>

| Model | Size | Token Density<sup>+</sup> | OpenCompass | MME | MMVet | OCRBench | MMMU val | MathVista mini | MMB1.1 test | AI2D | TextVQA val | DocVQA test | HallusionBench | Object HalBench |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| **Proprietary** | | | | | | | | | | | | | | |
| GPT-4o | - | 1088 | 69.9 | 2328.7 | 69.1 | 736 | 69.2 | 61.3 | 82.2 | 84.6 | - | 92.8 | 55.0 | 17.6 |
| Claude 3.5 Sonnet | - | 750 | 67.9 | 1920.0 | 66.0 | 788 | 65.9 | 61.6 | 78.5 | 80.2 | - | 95.2 | 49.9 | 13.8 |
| Gemini 1.5 Pro | - | - | 64.4 | 2110.6 | 64.0 | 754 | 60.6 | 57.7 | 73.9 | 79.1 | 73.5 | 86.5 | 45.6 | - |
| GPT-4o mini | - | 1088 | 64.1 | 2003.4 | 66.9 | 785 | 60.0 | 52.4 | 76.0 | 77.8 | - | - | 46.1 | 12.4 |
| GPT-4V | - | 1088 | 63.5 | 2070.2 | 67.5 | 656 | 61.7 | 54.7 | 79.8 | 78.6 | 78.0 | 87.2 | 43.9 | 14.2 |
| Step-1V | - | - | 59.5 | 2206.4 | 63.3 | 625 | 49.9 | 44.8 | 78.0 | 79.2 | 71.6 | - | 48.4 | - |
| Qwen-VL-Max | - | 784 | 58.3 | 2281.7 | 61.8 | 684 | 52.0 | 43.4 | 74.6 | 75.7 | 79.5 | 93.1 | 41.2 | 13.4 |
| **Open-source** | | | | | | | | | | | | | | |
| LLaVA-NeXT-Yi-34B | 34B | 157 | 55.0 | 2006.5 | 50.7 | 574 | 48.8 | 40.4 | 77.8 | 78.9 | 69.3 | - | 34.8 | 12.6 |
| Mini-Gemini-HD-34B | 34B | 157 | - | 2141.0 | 59.3 | 518 | 48.0 | 43.3 | - | 80.5 | 74.1 | 78.9 | - | - |
| Cambrian-34B | 34B | 1820 | 58.3 | 2049.9 | 53.2 | 591 | 50.4 | 50.3 | 77.8 | 79.5 | 76.7 | 75.5 | 41.6 | 14.7 |
| GLM-4V-9B | 13B | 784 | 59.1 | 2018.8 | 58.0 | 776 | 46.9 | 51.1 | 67.9 | 71.2 | - | - | 45.0 | - |
| InternVL2-8B | 8B | 706 | 64.1 | 2215.1 | 54.3 | 794 | **51.2** | 58.3 | **79.4** | **83.6** | 77.4 | **91.6** | 45.0 | 21.3 |
| MiniCPM-Llama3-V 2.5 | 8B | 1882 | 58.8 | 2024.6 | 52.8 | 725 | 45.8 | 54.3 | 72.0 | 78.4 | 76.6 | 84.8 | 42.4 | 10.3 |
| MiniCPM-V 2.6 | 8B | **2822** | **65.2** | **2348.4**\* | **60.0** | **852**\* | 49.8\* | **60.6** | 78.0 | 82.1 | **80.1** | 90.8 | **48.1**\* | **8.2** |

\* We evaluate this benchmark using chain-of-thought prompting. Specifically, for MME, we used this technique only for the Cognition set.

<sup>+</sup> Token Density: number of pixels encoded into each visual token at maximum resolution, i.e., # pixels at maximum resolution / # visual tokens.

Note: For proprietary models, we calculate token density based on the image encoding charging strategy defined in the official API documentation, which provides an upper-bound estimation.
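For example, MiniCPM-V 2.6 encodes a 1.8M pixel (1344x1344) image into 640 visual tokens, so its token density is 1344 × 1344 / 640 ≈ 2822, which is the value reported in the table above.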
</details>

<details>
<summary>Click to view multi-image results on Mantis Eval, BLINK, Mathverse mv, Sciverse mv, MIRB.</summary>

| Model | Size | Mantis Eval | BLINK val | Mathverse mv | Sciverse mv | MIRB |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|
| **Proprietary** | | | | | | |
| GPT-4V | - | 62.7 | 54.6 | 60.3 | 66.9 | 53.1 |
| LLaVA-NeXT-Interleave-14B | 14B | 66.4 | 52.6 | 32.7 | 30.2 | - |
| **Open-source** | | | | | | |
| Emu2-Chat | 37B | 37.8 | 36.2 | - | 27.2 | - |
| CogVLM | 17B | 45.2 | 41.1 | - | - | - |
| VPG-C | 7B | 52.4 | 43.1 | 24.3 | 23.1 | - |
| VILA 8B | 8B | 51.2 | 39.3 | - | 36.5 | - |
| InternLM-XComposer-2.5 | 8B | 53.1* | 48.9 | 32.1* | - | 42.5 |
| InternVL2-8B | 8B | 59.0* | 50.9 | 30.5* | 34.4* | **56.9**\* |
| MiniCPM-V 2.6 | 8B | **69.1** | **53.0** | **84.9** | **74.9** | 53.8 |

* We evaluate the officially released checkpoint by ourselves.

</details>
<details>
|
||||
<summary>Click to view video results on Video-MME and Video-ChatGPT.</summary>
|
||||
<div align="center">
|
||||
<table style="margin: 0px auto;">
|
||||
<thead>
|
||||
<tr>
|
||||
<th align="left">Model</th>
|
||||
<th>Size</th>
|
||||
<th colspan="2">Video-MME</th>
|
||||
<th colspan="5">Video-ChatGPT</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<th align="left"></th>
|
||||
<th></th>
|
||||
<th>w/o subs</th>
|
||||
<th>w subs</th>
|
||||
<th>Correctness</th>
|
||||
<th>Detail</th>
|
||||
<th>Context</th>
|
||||
<th>Temporal</th>
|
||||
<th>Consistency</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody align="center">
|
||||
<tr>
|
||||
<td colspan="9" align="left"><strong>Proprietary</strong></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left">Claude 3.5 Sonnet</td>
|
||||
<td>-</td>
|
||||
<td>60.0</td>
|
||||
<td>62.9</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left">GPT-4V</td>
|
||||
<td>-</td>
|
||||
<td>59.9</td>
|
||||
<td>63.3</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td colspan="9" align="left"><strong>Open-source</strong></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left">LLaVA-NeXT-7B</td>
|
||||
<td>7B</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>3.39</td>
|
||||
<td>3.29</td>
|
||||
<td>3.92</td>
|
||||
<td>2.60</td>
|
||||
<td>3.12</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left">LLaVA-NeXT-34B</td>
|
||||
<td>34B</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>3.29</td>
|
||||
<td>3.23</td>
|
||||
<td>3.83</td>
|
||||
<td>2.51</td>
|
||||
<td>3.47</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left">CogVLM2-Video</td>
|
||||
<td>12B</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>3.49</td>
|
||||
<td><strong>3.46</strong></td>
|
||||
<td>3.23</td>
|
||||
<td><strong>2.98</strong></td>
|
||||
<td><strong>3.64</strong></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left">LongVA</td>
|
||||
<td>7B</td>
|
||||
<td>52.4</td>
|
||||
<td>54.3</td>
|
||||
<td>3.05</td>
|
||||
<td>3.09</td>
|
||||
<td>3.77</td>
|
||||
<td>2.44</td>
|
||||
<td><strong>3.64</strong></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left">InternVL2-8B</td>
|
||||
<td>8B</td>
|
||||
<td>54.0</td>
|
||||
<td>56.9</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left">InternLM-XComposer-2.5</td>
|
||||
<td>8B</td>
|
||||
<td>55.8</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
<td>-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="left">LLaVA-NeXT-Video</td>
|
||||
<td>32B</td>
|
||||
<td>60.2</td>
|
||||
<td>63.0</td>
|
||||
<td>3.48</td>
|
||||
<td>3.37</td>
|
||||
<td><strong>3.95</strong></td>
|
||||
<td>2.64</td>
|
||||
<td>3.28</td>
|
||||
</tr>
|
||||
<tr style="background-color: #e6f2ff;">
|
||||
<td nowrap="nowrap" align="left">MiniCPM-V 2.6</td>
|
||||
<td>8B</td>
|
||||
<td><strong>60.9</strong></td>
|
||||
<td><strong>63.6</strong></td>
|
||||
<td><strong>3.59</strong></td>
|
||||
<td>3.28</td>
|
||||
<td>3.93</td>
|
||||
<td>2.73</td>
|
||||
<td>3.62</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</details>

<details>
<summary>Click to view few-shot results on TextVQA, VizWiz, VQAv2, OK-VQA.</summary>
<div align="center">

<table style="margin: 0px auto;">
<thead>
<tr><th align="left">Model</th><th>Size</th><th>Shot</th><th>TextVQA val</th><th>VizWiz test-dev</th><th>VQAv2 test-dev</th><th>OK-VQA val</th></tr>
</thead>
<tbody align="center">
<tr><td align="left" nowrap="nowrap" rowspan="3">Flamingo</td><td rowspan="3">80B</td><td>0*</td><td>35.0</td><td>31.6</td><td>56.3</td><td>40.6</td></tr>
<tr><td>4</td><td>36.5</td><td>39.6</td><td>63.1</td><td><strong>57.4</strong></td></tr>
<tr><td>8</td><td>37.3</td><td>44.8</td><td>65.6</td><td>57.5</td></tr>
<tr><td align="left" nowrap="nowrap" rowspan="3">IDEFICS</td><td rowspan="3">80B</td><td>0*</td><td>30.9</td><td>36.0</td><td>60.0</td><td>45.2</td></tr>
<tr><td>4</td><td>34.3</td><td>40.4</td><td>63.6</td><td>52.4</td></tr>
<tr><td>8</td><td>35.7</td><td>46.1</td><td>64.8</td><td>55.1</td></tr>
<tr><td align="left" nowrap="nowrap" rowspan="3">OmniCorpus</td><td rowspan="3">7B</td><td>0*</td><td>43.0</td><td>49.8</td><td>63.2</td><td>45.5</td></tr>
<tr><td>4</td><td>45.4</td><td>51.3</td><td>64.5</td><td>46.5</td></tr>
<tr><td>8</td><td>45.6</td><td>52.2</td><td>64.7</td><td>46.6</td></tr>
<tr><td align="left" nowrap="nowrap" rowspan="3">Emu2</td><td rowspan="3">37B</td><td>0</td><td>26.4</td><td>40.4</td><td>33.5</td><td>26.7</td></tr>
<tr><td>4</td><td>48.2</td><td>54.6</td><td>67.0</td><td>53.2</td></tr>
<tr><td>8</td><td>49.3</td><td>54.7</td><td>67.8</td><td>54.1</td></tr>
<tr><td align="left" nowrap="nowrap" rowspan="2">MM1</td><td rowspan="2">30B</td><td>0</td><td>26.2</td><td>40.4</td><td>48.9</td><td>26.7</td></tr>
<tr><td>8</td><td>49.3</td><td>54.7</td><td><strong>70.9</strong></td><td>54.1</td></tr>
<tr style="background-color: #e6f2ff;"><td align="left" nowrap="nowrap" rowspan="3">MiniCPM-V 2.6<sup>+</sup></td><td rowspan="3">8B</td><td>0</td><td>43.9</td><td>33.8</td><td>45.4</td><td>23.9</td></tr>
<tr style="background-color: #e6f2ff;"><td>4</td><td>63.6</td><td>60.5</td><td>65.5</td><td>50.1</td></tr>
<tr style="background-color: #e6f2ff;"><td>8</td><td><strong>64.6</strong></td><td><strong>63.4</strong></td><td>68.2</td><td>51.4</td></tr>
</tbody>
</table>

</div>
* denotes zero image shot and two additional text shots following Flamingo.

<sup>+</sup> We evaluate the pretraining ckpt without SFT.
</details>

### Examples <!-- omit in toc -->
|
||||
|
||||
<div style="display: flex; flex-direction: column; align-items: center;">
|
||||
<img src="../assets/minicpmv2_6/multi_img-bike.png" alt="Bike" style="margin-bottom: 5px;">
|
||||
<img src="../assets/minicpmv2_6/multi_img-menu.png" alt="Menu" style="margin-bottom: 5px;">
|
||||
<img src="../assets/minicpmv2_6/multi_img-code.png" alt="Code" style="margin-bottom: 5px;">
|
||||
<img src="../assets/minicpmv2_6/ICL-Mem.png" alt="Mem" style="margin-bottom: 5px;">
|
||||
<img src="../assets/minicpmv2_6/multiling-medal.png" alt="medal" style="margin-bottom: 10px;">
|
||||
</div>
|
||||
<details>
|
||||
<summary>Click to view more cases.</summary>
|
||||
<div style="display: flex; flex-direction: column; align-items: center;">
|
||||
<img src="../assets/minicpmv2_6/ICL-elec.png" alt="elec" style="margin-bottom: 5px;">
|
||||
<img src="../assets/minicpmv2_6/multiling-olympic.png" alt="Menu" style="margin-bottom: 10px;">
|
||||
</div>
|
||||
</details>
|
||||
|
||||
We deploy MiniCPM-V 2.6 on end devices. The demo video is a raw screen recording on an iPad Pro without editing.
|
||||
|
||||
<table align="center">
|
||||
<p align="center">
|
||||
<img src="../assets/gif_cases/ai.gif" width=32%/>
|
||||
|
||||
<img src="../assets/gif_cases/beer.gif" width=32%/>
|
||||
</p>
|
||||
</table>
|
||||
|
||||
<table align="center">
|
||||
<p align="center">
|
||||
<img src="../assets/gif_cases/ticket.gif" width=32%/>
|
||||
|
||||
<img src="../assets/gif_cases/wfh.gif" width=32%/>
|
||||
</p>
|
||||
</table>
|
||||
|
||||
<table align="center">
|
||||
<p align="center">
|
||||
<video src="https://github.com/user-attachments/assets/21f4b818-ede1-4822-920e-91281725c830" width="360" /> </video>
|
||||
<!-- <video src="https://github.com/user-attachments/assets/c835f757-206b-4d9c-8e36-70d67b453628" width="360" /> </video> -->
|
||||
</p>
|
||||
</table>
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
### Multi-turn Conversation
|
||||
|
||||
|
||||
<div align="center">
|
||||
<img src="../assets/airplane.jpeg" width="500px">
|
||||
</div>
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True,
|
||||
attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
|
||||
model = model.eval().cuda()
|
||||
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True)
|
||||
|
||||
image = Image.open('./assets/airplane.jpeg').convert('RGB')
|
||||
|
||||
# First round chat
|
||||
question = "Tell me the model of this aircraft."
|
||||
msgs = [{'role': 'user', 'content': [image, question]}]
|
||||
|
||||
answer = model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
print(answer)
|
||||
|
||||
# Second round chat
|
||||
# pass history context of multi-turn conversation
|
||||
msgs.append({"role": "assistant", "content": [answer]})
|
||||
msgs.append({"role": "user", "content": ["Introduce something about Airbus A380."]})
|
||||
|
||||
answer = model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
print(answer)
|
||||
```
|
||||
|
||||
You should get output similar to the following:
|
||||
|
||||
```
|
||||
"The aircraft in the image is an Airbus A380, which can be identified by its large size, double-deck structure, and the distinctive shape of its wings and engines. The A380 is a wide-body aircraft known for being the world's largest passenger airliner, designed for long-haul flights. It has four engines, which are characteristic of large commercial aircraft. The registration number on the aircraft can also provide specific information about the model if looked up in an aviation database."
|
||||
|
||||
"The Airbus A380 is a double-deck, wide-body, four-engine jet airliner made by Airbus. It is the world's largest passenger airliner and is known for its long-haul capabilities. The aircraft was developed to improve efficiency and comfort for passengers traveling over long distances. It has two full-length passenger decks, which can accommodate more passengers than a typical single-aisle airplane. The A380 has been operated by airlines such as Lufthansa, Singapore Airlines, and Emirates, among others. It is widely recognized for its unique design and significant impact on the aviation industry."
|
||||
```
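The same pattern extends to any number of turns: keep appending the assistant reply and the next user message to `msgs` and call `chat` again. A minimal sketch, reusing the `model`, `tokenizer`, `msgs`, and `answer` variables from the example above (the follow-up prompts are illustrative only):

```python
# Continue the conversation for a few more illustrative turns.
follow_ups = [
    "How many passengers can it typically carry?",
    "What are the main differences between it and the Boeing 747?",
]

for user_text in follow_ups:
    msgs.append({"role": "assistant", "content": [answer]})  # previous model reply
    msgs.append({"role": "user", "content": [user_text]})    # next user turn
    answer = model.chat(image=None, msgs=msgs, tokenizer=tokenizer)
    print(answer)
```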
|
||||
|
||||
#### Multi-image Understanding
|
||||
<details>
|
||||
<summary> Click to view Python example of MiniCPM-V 2.6 multi-image understanding </summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True,
|
||||
attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
|
||||
model = model.eval().cuda()
|
||||
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True)
|
||||
|
||||
image1 = Image.open('image1.jpg').convert('RGB')
|
||||
image2 = Image.open('image2.jpg').convert('RGB')
|
||||
question = 'Compare image 1 and image 2, tell me about the differences between image 1 and image 2.'
|
||||
|
||||
msgs = [{'role': 'user', 'content': [image1, image2, question]}]
|
||||
|
||||
answer = model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
print(answer)
|
||||
```
|
||||
</details>
|
||||
|
||||
#### Few-shot In-Context-Learning
|
||||
|
||||
<details>
|
||||
<summary> Click to view Python example of MiniCPM-V 2.6 few-shot in-context-learning example </summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True,
|
||||
attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
|
||||
model = model.eval().cuda()
|
||||
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True)
|
||||
|
||||
question = "production date"
|
||||
image1 = Image.open('example1.jpg').convert('RGB')
|
||||
answer1 = "2023.08.04"
|
||||
image2 = Image.open('example2.jpg').convert('RGB')
|
||||
answer2 = "2007.04.24"
|
||||
image_test = Image.open('test.jpg').convert('RGB')
|
||||
|
||||
msgs = [
|
||||
{'role': 'user', 'content': [image1, question]}, {'role': 'assistant', 'content': [answer1]},
|
||||
{'role': 'user', 'content': [image2, question]}, {'role': 'assistant', 'content': [answer2]},
|
||||
{'role': 'user', 'content': [image_test, question]}
|
||||
]
|
||||
|
||||
answer = model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
print(answer)
|
||||
```
|
||||
</details>
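If you have more than two demonstrations, the same message list can be built programmatically from (image, answer) pairs. A small sketch based on the example above, reusing `model` and `tokenizer`; the file names and answers below are placeholders:

```python
from PIL import Image

question = "production date"
# (image_path, ground-truth answer) pairs used as in-context examples -- placeholders
shots = [
    ("example1.jpg", "2023.08.04"),
    ("example2.jpg", "2007.04.24"),
    ("example3.jpg", "2015.11.30"),
]

msgs = []
for path, ans in shots:
    msgs.append({'role': 'user', 'content': [Image.open(path).convert('RGB'), question]})
    msgs.append({'role': 'assistant', 'content': [ans]})
# query image goes last, with the same question
msgs.append({'role': 'user', 'content': [Image.open('test.jpg').convert('RGB'), question]})

answer = model.chat(image=None, msgs=msgs, tokenizer=tokenizer)
print(answer)
```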
|
||||
|
||||
#### Video understanding
|
||||
<details>
|
||||
<summary> Click to view Python example of MiniCPM-V 2.6 video understanding </summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from decord import VideoReader, cpu # pip install decord
|
||||
|
||||
model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True,
|
||||
attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
|
||||
model = model.eval().cuda()
|
||||
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True)
|
||||
|
||||
MAX_NUM_FRAMES=64 # if cuda OOM set a smaller number
|
||||
|
||||
def encode_video(video_path):
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
idxs = [int(i * gap + gap / 2) for i in range(n)]
|
||||
return [l[i] for i in idxs]
|
||||
|
||||
vr = VideoReader(video_path, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1) # FPS
|
||||
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
||||
if len(frame_idx) > MAX_NUM_FRAMES:
|
||||
frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
||||
print('num frames:', len(frames))
|
||||
return frames
|
||||
|
||||
video_path="video_test.mp4"
|
||||
frames = encode_video(video_path)
|
||||
question = "Describe the video"
|
||||
msgs = [
|
||||
{'role': 'user', 'content': frames + [question]},
|
||||
]
|
||||
|
||||
# Set decode params for video
|
||||
params = {}
|
||||
params["use_image_id"] = False
|
||||
params["max_slice_nums"] = 2 # 如果cuda OOM且视频分辨率大于448*448可设为1
|
||||
|
||||
answer = model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=tokenizer,
|
||||
**params
|
||||
)
|
||||
print(answer)
|
||||
```
|
||||
</details>
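For long videos, sampling one frame per second can still exceed `MAX_NUM_FRAMES`. A variant of `encode_video` that samples at a fixed target FPS keeps the frame count lower up front; this is a sketch reusing `VideoReader`, `cpu`, `Image`, and `MAX_NUM_FRAMES` from the example above, and the 0.5 FPS value is an arbitrary illustrative choice:

```python
def encode_video_at_fps(video_path, target_fps=0.5, max_frames=MAX_NUM_FRAMES):
    # Sample frames at roughly `target_fps` frames per second of video.
    vr = VideoReader(video_path, ctx=cpu(0))
    step = max(1, round(vr.get_avg_fps() / target_fps))
    frame_idx = list(range(0, len(vr), step))[:max_frames]
    frames = vr.get_batch(frame_idx).asnumpy()
    return [Image.fromarray(f.astype('uint8')) for f in frames]
```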
|
||||
@@ -1,6 +1,6 @@
|
||||
## OmniLMM-12B
|
||||
|
||||
> OmniLMM-12B was released in the early stage of this project. We recommend using our [recently released models](./README_en.md) for better performance and efficiency.

> OmniLMM-12B was released in the early stage of this project. We recommend using our [recently released models](./README.md) for better performance and efficiency.
|
||||
|
||||
> Archived at: 2024-05-19
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
<div align="center">
|
||||
<img src="../assets/minicpm-v25.png" width="60%"/>
|
||||
<img src="../assets/wechat-QR.jpeg" width="60%"/>
|
||||
|
||||
<p> 扫码加入「MiniCPM-V 交流群」 </p>
|
||||
<p> Scan the QR code to join the "MiniCPM-V Discussion Group" </p>
|
||||
<p> 扫码加入「MiniCPM-o 交流群」 </p>
|
||||
<p> Scan the QR code to join the "MiniCPM-o Discussion Group" </p>
|
||||
</div>
|
||||
|
||||
@@ -1,60 +1,56 @@
|
||||
# Evaluation
|
||||
|
||||
## opencompass
|
||||
## MiniCPM-o 2.6
|
||||
|
||||
### opencompass
|
||||
First, enter the `vlmevalkit` directory and install all dependencies:
|
||||
```bash
|
||||
cd vlmevalkit
|
||||
pip install -r requirements.txt
|
||||
pip install --upgrade pip
|
||||
pip install -e .
|
||||
wget https://download.pytorch.org/whl/cu118/torch-2.2.0%2Bcu118-cp310-cp310-linux_x86_64.whl#sha256=4377e0a7fe8ff8ffc4f7c9c6130c1dcd3874050ae4fc28b7ff1d35234fbca423
|
||||
wget https://download.pytorch.org/whl/cu118/torchvision-0.17.0%2Bcu118-cp310-cp310-linux_x86_64.whl#sha256=2e63d62e09d9b48b407d3e1b30eb8ae4e3abad6968e8d33093b60d0657542428
|
||||
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
||||
pip install torch-2.2.0%2Bcu118-cp310-cp310-linux_x86_64.whl
|
||||
pip install torchvision-0.17.0%2Bcu118-cp310-cp310-linux_x86_64.whl
|
||||
pip install flash_attn-2.6.3+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
||||
```
|
||||
<br />
|
||||
|
||||
Then, run `script/run_inference.sh`, which receives three input parameters in sequence: `MODELNAME`, `DATALIST`, and `MODE`. `MODELNAME` represents the name of the model, `DATALIST` represents the datasets used for inference, and `MODE` represents evaluation mode:
|
||||
Then, run `scripts/run_inference.sh`, which receives two input parameters in sequence: `MODELNAME` and `DATALIST`. `MODELNAME` represents the name of the model, `DATALIST` represents the datasets used for inference:
|
||||
```bash
|
||||
chmod +x ./script/run_inference.sh
|
||||
./script/run_inference.sh $MODELNAME $DATALIST $MODE
|
||||
chmod +x ./scripts/run_inference.sh
|
||||
./scripts/run_inference.sh $MODELNAME $DATALIST
|
||||
```
|
||||
<br />
|
||||
|
||||
The three available choices for `MODELNAME` are listed in `vlmeval/config.py`:
|
||||
The five available choices for `MODELNAME` are listed in `vlmeval/config.py`:
|
||||
```bash
|
||||
ungrouped = {
|
||||
'MiniCPM-V':partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
|
||||
'MiniCPM-V-2':partial(MiniCPM_V, model_path='openbmb/MiniCPM-V-2'),
|
||||
'MiniCPM-Llama3-V-2_5':partial(MiniCPM_Llama3_V, model_path='openbmb/MiniCPM-Llama3-V-2_5'),
|
||||
minicpm_series = {
|
||||
'MiniCPM-V': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
|
||||
'MiniCPM-V-2': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V-2'),
|
||||
'MiniCPM-Llama3-V-2_5': partial(MiniCPM_Llama3_V, model_path='openbmb/MiniCPM-Llama3-V-2_5'),
|
||||
'MiniCPM-V-2_6': partial(MiniCPM_V_2_6, model_path='openbmb/MiniCPM-V-2_6'),
|
||||
'MiniCPM-o-2_6': partial(MiniCPM_o_2_6, model_path='openbmb/MiniCPM-o-2_6'),
|
||||
}
|
||||
```
|
||||
<br />
|
||||
|
||||
All available choices for `DATALIST` are listed in `vlmeval/utils/dataset_config.py`. While evaluating on a single dataset, call the dataset name directly without quotation marks; while evaluating on multiple datasets, separate the names of different datasets with spaces and add quotation marks at both ends:
|
||||
All available choices for `DATALIST` are listed in `vlmeval/utils/dataset_config.py`. While evaluating on multiple datasets at a time, separate the names of different datasets with spaces and add quotation marks at both ends:
|
||||
```bash
|
||||
$DATALIST="POPE ScienceQA_TEST ChartQA_TEST"
|
||||
$DATALIST="MMMU_DEV_VAL MathVista_MINI MMVet MMBench_DEV_EN_V11 MMBench_DEV_CN_V11 MMStar HallusionBench AI2D_TEST"
|
||||
```
|
||||
<br />
|
||||
|
||||
While scoring on each benchmark directly, set `MODE=all`. If only inference results are required, set `MODE=infer`. In order to reproduce the results in the table displayed on the homepage (columns between MME and RealWorldQA), you need to run the script according to the following settings:
|
||||
When the benchmark requires GPT series model for scoring, please specify `OPENAI_API_BASE` and `OPENAI_API_KEY` in the `.env` file.
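A minimal `.env` sketch with only those two entries (the values are placeholders; the full template is shown later in this document):

```bash
# .env (in the vlmevalkit root) -- placeholder values
OPENAI_API_KEY=your-api-key
OPENAI_API_BASE=your-api-base-url
```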
In order to reproduce the results on OpenCompass benchmarks together with ChartQA and MME, which are displayed in the table on the homepage (columns between OCRBench and HallusionBench), you need to run the script according to the following settings:
|
||||
```bash
|
||||
# run on all 7 datasets
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 "MME MMBench_TEST_EN MMBench_TEST_CN MMMU_DEV_VAL MathVista_MINI LLaVABench RealWorldQA" all
|
||||
|
||||
# The following are instructions for running on a single dataset
|
||||
# MME
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 MME all
|
||||
# MMBench_TEST_EN
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 MMBench_TEST_EN all
|
||||
# MMBench_TEST_CN
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 MMBench_TEST_CN all
|
||||
# MMMU_DEV_VAL
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 MMMU_DEV_VAL all
|
||||
# MathVista_MINI
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 MathVista_MINI all
|
||||
# LLaVABench
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 LLaVABench all
|
||||
# RealWorldQA
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 RealWorldQA all
|
||||
# Please note that we use different prompts for the perception and reasoning sets of MME. While evaluating on the reasoning subset, CoT is required, so you need to manually modify the judgment condition of the use_cot function in vlmeval/vlm/minicpm_v.py
|
||||
./scripts/run_inference.sh MiniCPM-o-2_6 "MMMU_DEV_VAL MathVista_MINI MMVet MMBench_TEST_EN_V11 MMBench_TEST_CN_V11 MMStar HallusionBench AI2D_TEST OCRBench ChartQA_TEST MME"
|
||||
```
|
||||
<br />
|
||||
|
||||
## vqadataset
|
||||
### vqadataset
|
||||
First, enter the `vqaeval` directory and install all dependencies. Then, create a `downloads` subdirectory to store the downloaded datasets for all tasks:
|
||||
```bash
|
||||
cd vqaeval
|
||||
@@ -112,7 +108,8 @@ chmod +x ./shell/run_inference.sh
|
||||
```
|
||||
<br />
|
||||
|
||||
All optional parameters are listed in `eval_utils/getargs.py`. The meanings of some major parameters are listed as follows:
|
||||
All optional parameters are listed in `eval_utils/getargs.py`. The meanings of some major parameters are listed as follows.
|
||||
For `MiniCPM-o-2_6`, set `model_name` to `minicpmo26`:
|
||||
```bash
|
||||
# path to images and their corresponding questions
|
||||
# TextVQA
|
||||
@@ -174,4 +171,373 @@ For the DocVQATest task, in order to upload the inference results to the [offici
|
||||
```bash
|
||||
chmod +x ./shell/run_transform.sh
|
||||
./shell/run_transform.sh
|
||||
```
|
||||
```
|
||||
|
||||
<br />
|
||||
|
||||
## MiniCPM-V 2.6
|
||||
|
||||
<details>
|
||||
<summary>Expand</summary>
|
||||
|
||||
### opencompass
|
||||
First, enter the `vlmevalkit` directory and install all dependencies:
|
||||
```bash
|
||||
cd vlmevalkit
|
||||
pip install --upgrade pip
|
||||
pip install -e .
|
||||
wget https://download.pytorch.org/whl/cu118/torch-2.2.0%2Bcu118-cp310-cp310-linux_x86_64.whl#sha256=4377e0a7fe8ff8ffc4f7c9c6130c1dcd3874050ae4fc28b7ff1d35234fbca423
|
||||
wget https://download.pytorch.org/whl/cu118/torchvision-0.17.0%2Bcu118-cp310-cp310-linux_x86_64.whl#sha256=2e63d62e09d9b48b407d3e1b30eb8ae4e3abad6968e8d33093b60d0657542428
|
||||
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
||||
pip install torch-2.2.0%2Bcu118-cp310-cp310-linux_x86_64.whl
|
||||
pip install torchvision-0.17.0%2Bcu118-cp310-cp310-linux_x86_64.whl
|
||||
pip install flash_attn-2.6.3+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
||||
```
|
||||
<br />
|
||||
|
||||
Then, run `scripts/run_inference.sh`, which receives three input parameters in sequence: `MODELNAME`, `DATALIST`, and `MODE`. `MODELNAME` represents the name of the model, `DATALIST` represents the datasets used for inference, and `MODE` represents evaluation mode:
|
||||
```bash
|
||||
chmod +x ./scripts/run_inference.sh
|
||||
./scripts/run_inference.sh $MODELNAME $DATALIST $MODE
|
||||
```
|
||||
<br />
|
||||
|
||||
The four available choices for `MODELNAME` are listed in `vlmeval/config.py`:
|
||||
```bash
|
||||
minicpm_series = {
|
||||
'MiniCPM-V': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
|
||||
'MiniCPM-V-2': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V-2'),
|
||||
'MiniCPM-Llama3-V-2_5': partial(MiniCPM_Llama3_V, model_path='openbmb/MiniCPM-Llama3-V-2_5'),
|
||||
'MiniCPM-V-2_6': partial(MiniCPM_V_2_6, model_path='openbmb/MiniCPM-V-2_6'),
|
||||
}
|
||||
```
|
||||
<br />
|
||||
|
||||
All available choices for `DATALIST` are listed in `vlmeval/utils/dataset_config.py`. Separate the names of different datasets with spaces and add quotation marks at both ends:
|
||||
```bash
|
||||
$DATALIST="MMMU_DEV_VAL MathVista_MINI MMVet MMBench_DEV_EN_V11 MMBench_DEV_CN_V11 MMStar HallusionBench AI2D_TEST"
|
||||
```
|
||||
<br />
|
||||
|
||||
While scoring on each benchmark directly, set `MODE=all`. If only inference results are required, set `MODE=infer`. In order to reproduce the results in the table displayed on the homepage (columns between MME and HallusionBench), you need to run the script according to the following settings:
|
||||
```bash
|
||||
# without CoT
|
||||
./scripts/run_inference.sh MiniCPM-V-2_6 "MMMU_DEV_VAL MathVista_MINI MMVet MMBench_DEV_EN_V11 MMBench_DEV_CN_V11 MMStar HallusionBench AI2D_TEST" all
|
||||
./scripts/run_inference.sh MiniCPM-V-2_6 MME all
|
||||
# with CoT
|
||||
# While running the CoT version of MME, you need to modify the 'use_cot' function in vlmeval/vlm/minicpm_v.py and add MME to the branch that returns True.
|
||||
./scripts/run_inference.sh MiniCPM-V-2_6 "MMMU_DEV_VAL MMVet MMStar HallusionBench OCRBench" all
|
||||
./scripts/run_inference.sh MiniCPM-V-2_6 MME all
|
||||
```
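The modification mentioned in the comment above might look roughly like the following. This is a hypothetical sketch only; the actual `use_cot` function in `vlmeval/vlm/minicpm_v.py` may be structured differently, so adapt it to the real function body:

```python
# Hypothetical sketch of the change described above (not the actual source).
def use_cot(self, dataset):
    # Add 'MME' to this branch to enable chain-of-thought prompting for MME.
    if dataset in ['MMMU_DEV_VAL', 'MMVet', 'MMStar', 'HallusionBench', 'OCRBench', 'MME']:
        return True
    return False
```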
|
||||
<br />
|
||||
|
||||
### vqadataset
|
||||
First, enter the `vqaeval` directory and install all dependencies. Then, create a `downloads` subdirectory to store the downloaded datasets for all tasks:
|
||||
```bash
|
||||
cd vqaeval
|
||||
pip install -r requirements.txt
|
||||
mkdir downloads
|
||||
```
|
||||
<br />
|
||||
|
||||
Download the datasets from the following links and place them in the specified directories:
|
||||
###### TextVQA
|
||||
```bash
|
||||
cd downloads
|
||||
mkdir TextVQA && cd TextVQA
|
||||
wget https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
|
||||
unzip train_val_images.zip && rm train_val_images.zip
|
||||
mv train_val_images/train_images . && rm -rf train_val_images
|
||||
wget https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
|
||||
cd ../..
|
||||
```
|
||||
|
||||
###### DocVQA / DocVQATest
|
||||
|
||||
```bash
|
||||
cd downloads
|
||||
mkdir DocVQA && cd DocVQA && mkdir spdocvqa_images
|
||||
# Download Images and Annotations from Task 1 - Single Page Document Visual Question Answering at https://rrc.cvc.uab.es/?ch=17&com=downloads
|
||||
# Move the spdocvqa_images.tar.gz and spdocvqa_qas.zip to DocVQA directory
|
||||
tar -zxvf spdocvqa_images.tar.gz -C spdocvqa_images && rm spdocvqa_images.tar.gz
|
||||
unzip spdocvqa_qas.zip && rm spdocvqa_qas.zip
|
||||
cp spdocvqa_qas/val_v1.0_withQT.json . && cp spdocvqa_qas/test_v1.0.json . && rm -rf spdocvqa_qas
|
||||
cd ../..
|
||||
```
|
||||
<br />
|
||||
|
||||
The `downloads` directory should be organized according to the following structure:
|
||||
```bash
|
||||
downloads
|
||||
├── TextVQA
|
||||
│ ├── train_images
|
||||
│ │ ├── ...
|
||||
│ ├── TextVQA_0.5.1_val.json
|
||||
├── DocVQA
|
||||
│ ├── spdocvqa_images
|
||||
│ │ ├── ...
|
||||
│ ├── val_v1.0_withQT.json
|
||||
│ ├── test_v1.0.json
|
||||
```
|
||||
<br />
|
||||
|
||||
Modify the parameters in `shell/run_inference.sh` and run inference:
|
||||
|
||||
```bash
|
||||
chmod +x ./shell/run_inference.sh
|
||||
./shell/run_inference.sh
|
||||
```
|
||||
<br />
|
||||
|
||||
All optional parameters are listed in `eval_utils/getargs.py`. The meanings of some major parameters are listed as follows.
|
||||
For `MiniCPM-V-2_6`, set `model_name` to `minicpmv26`:
|
||||
```bash
|
||||
# path to images and their corresponding questions
|
||||
# TextVQA
|
||||
--textVQA_image_dir
|
||||
--textVQA_ann_path
|
||||
# DocVQA
|
||||
--docVQA_image_dir
|
||||
--docVQA_ann_path
|
||||
# DocVQATest
|
||||
--docVQATest_image_dir
|
||||
--docVQATest_ann_path
|
||||
|
||||
# whether to eval on certain task
|
||||
--eval_textVQA
|
||||
--eval_docVQA
|
||||
--eval_docVQATest
|
||||
--eval_all
|
||||
|
||||
# model name and model path
|
||||
--model_name
|
||||
--model_path
|
||||
# load model from ckpt
|
||||
--ckpt
|
||||
# the way the model processes input data, "interleave" represents interleaved image-text form, while "old" represents non-interleaved.
|
||||
--generate_method
|
||||
|
||||
--batchsize
|
||||
|
||||
# path to save the outputs
|
||||
--answer_path
|
||||
```
|
||||
<br />
|
||||
|
||||
While evaluating on different tasks, parameters need to be set as follows:
|
||||
###### TextVQA
|
||||
```bash
|
||||
--eval_textVQA
|
||||
--textVQA_image_dir ./downloads/TextVQA/train_images
|
||||
--textVQA_ann_path ./downloads/TextVQA/TextVQA_0.5.1_val.json
|
||||
```
|
||||
|
||||
###### DocVQA
|
||||
```bash
|
||||
--eval_docVQA
|
||||
--docVQA_image_dir ./downloads/DocVQA/spdocvqa_images
|
||||
--docVQA_ann_path ./downloads/DocVQA/val_v1.0_withQT.json
|
||||
```
|
||||
|
||||
###### DocVQATest
|
||||
```bash
|
||||
--eval_docVQATest
|
||||
--docVQATest_image_dir ./downloads/DocVQA/spdocvqa_images
|
||||
--docVQATest_ann_path ./downloads/DocVQA/test_v1.0.json
|
||||
```
|
||||
|
||||
<br />
|
||||
|
||||
For the DocVQATest task, in order to upload the inference results to the [official website](https://rrc.cvc.uab.es/?ch=17) for evaluation, run `shell/run_transform.sh` for format transformation after inference. `input_file_path` represents the path to the original output json, `output_file_path` represents the path to the transformed json:
|
||||
```bash
|
||||
chmod +x ./shell/run_transform.sh
|
||||
./shell/run_transform.sh
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<br />
|
||||
|
||||
## MiniCPM-Llama3-V-2_5
|
||||
|
||||
<details>
|
||||
<summary>Expand</summary>
|
||||
|
||||
### opencompass
|
||||
First, enter the `vlmevalkit` directory and install all dependencies:
|
||||
```bash
|
||||
cd vlmevalkit
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
<br />
|
||||
|
||||
Then, run `scripts/run_inference.sh`, which receives three input parameters in sequence: `MODELNAME`, `DATALIST`, and `MODE`. `MODELNAME` represents the name of the model, `DATALIST` represents the datasets used for inference, and `MODE` represents evaluation mode:
|
||||
```bash
|
||||
chmod +x ./scripts/run_inference.sh
|
||||
./scripts/run_inference.sh $MODELNAME $DATALIST $MODE
|
||||
```
|
||||
<br />
|
||||
|
||||
The three available choices for `MODELNAME` are listed in `vlmeval/config.py`:
|
||||
```bash
|
||||
ungrouped = {
|
||||
'MiniCPM-V':partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
|
||||
'MiniCPM-V-2':partial(MiniCPM_V, model_path='openbmb/MiniCPM-V-2'),
|
||||
'MiniCPM-Llama3-V-2_5':partial(MiniCPM_Llama3_V, model_path='openbmb/MiniCPM-Llama3-V-2_5'),
|
||||
}
|
||||
```
|
||||
<br />
|
||||
|
||||
All available choices for `DATALIST` are listed in `vlmeval/utils/dataset_config.py`. While evaluating on a single dataset, call the dataset name directly without quotation marks; while evaluating on multiple datasets, separate the names of different datasets with spaces and add quotation marks at both ends:
|
||||
```bash
|
||||
$DATALIST="POPE ScienceQA_TEST ChartQA_TEST"
|
||||
```
|
||||
<br />
|
||||
|
||||
While scoring on each benchmark directly, set `MODE=all`. If only inference results are required, set `MODE=infer`. In order to reproduce the results in the table displayed on the homepage (columns between MME and RealWorldQA), you need to run the script according to the following settings:
|
||||
```bash
|
||||
# run on all 7 datasets
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 "MME MMBench_TEST_EN MMBench_TEST_CN MMMU_DEV_VAL MathVista_MINI LLaVABench RealWorldQA" all
|
||||
|
||||
# The following are instructions for running on a single dataset
|
||||
# MME
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MME all
|
||||
# MMBench_TEST_EN
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MMBench_TEST_EN all
|
||||
# MMBench_TEST_CN
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MMBench_TEST_CN all
|
||||
# MMMU_DEV_VAL
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MMMU_DEV_VAL all
|
||||
# MathVista_MINI
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MathVista_MINI all
|
||||
# LLaVABench
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 LLaVABench all
|
||||
# RealWorldQA
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 RealWorldQA all
|
||||
```
|
||||
<br />
|
||||
|
||||
### vqadataset
|
||||
First, enter the `vqaeval` directory and install all dependencies. Then, create a `downloads` subdirectory to store the downloaded datasets for all tasks:
|
||||
```bash
|
||||
cd vqaeval
|
||||
pip install -r requirements.txt
|
||||
mkdir downloads
|
||||
```
|
||||
<br />
|
||||
|
||||
Download the datasets from the following links and place them in the specified directories:
|
||||
###### TextVQA
|
||||
```bash
|
||||
cd downloads
|
||||
mkdir TextVQA && cd TextVQA
|
||||
wget https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
|
||||
unzip train_val_images.zip && rm train_val_images.zip
|
||||
mv train_val_images/train_images . && rm -rf train_val_images
|
||||
wget https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
|
||||
cd ../..
|
||||
```
|
||||
|
||||
###### DocVQA / DocVQATest
|
||||
|
||||
```bash
|
||||
cd downloads
|
||||
mkdir DocVQA && cd DocVQA && mkdir spdocvqa_images
|
||||
# Download Images and Annotations from Task 1 - Single Page Document Visual Question Answering at https://rrc.cvc.uab.es/?ch=17&com=downloads
|
||||
# Move the spdocvqa_images.tar.gz and spdocvqa_qas.zip to DocVQA directory
|
||||
tar -zxvf spdocvqa_images.tar.gz -C spdocvqa_images && rm spdocvqa_images.tar.gz
|
||||
unzip spdocvqa_qas.zip && rm spdocvqa_qas.zip
|
||||
cp spdocvqa_qas/val_v1.0_withQT.json . && cp spdocvqa_qas/test_v1.0.json . && rm -rf spdocvqa_qas
|
||||
cd ../..
|
||||
```
|
||||
<br />
|
||||
|
||||
The `downloads` directory should be organized according to the following structure:
|
||||
```bash
|
||||
downloads
|
||||
├── TextVQA
|
||||
│ ├── train_images
|
||||
│ │ ├── ...
|
||||
│ ├── TextVQA_0.5.1_val.json
|
||||
├── DocVQA
|
||||
│ ├── spdocvqa_images
|
||||
│ │ ├── ...
|
||||
│ ├── val_v1.0_withQT.json
|
||||
│ ├── test_v1.0.json
|
||||
```
|
||||
<br />
|
||||
|
||||
Modify the parameters in `shell/run_inference.sh` and run inference:
|
||||
|
||||
```bash
|
||||
chmod +x ./shell/run_inference.sh
|
||||
./shell/run_inference.sh
|
||||
```
|
||||
<br />
|
||||
|
||||
All optional parameters are listed in `eval_utils/getargs.py`. The meanings of some major parameters are listed as follows.
|
||||
For `MiniCPM-Llama3-V-2_5`, set `model_name` to `minicpmv`:
|
||||
```bash
|
||||
# path to images and their corresponding questions
|
||||
# TextVQA
|
||||
--textVQA_image_dir
|
||||
--textVQA_ann_path
|
||||
# DocVQA
|
||||
--docVQA_image_dir
|
||||
--docVQA_ann_path
|
||||
# DocVQATest
|
||||
--docVQATest_image_dir
|
||||
--docVQATest_ann_path
|
||||
|
||||
# whether to eval on certain task
|
||||
--eval_textVQA
|
||||
--eval_docVQA
|
||||
--eval_docVQATest
|
||||
--eval_all
|
||||
|
||||
# model name and model path
|
||||
--model_name
|
||||
--model_path
|
||||
# load model from ckpt
|
||||
--ckpt
|
||||
# the way the model processes input data, "interleave" represents interleaved image-text form, while "old" represents non-interleaved.
|
||||
--generate_method
|
||||
|
||||
--batchsize
|
||||
|
||||
# path to save the outputs
|
||||
--answer_path
|
||||
```
|
||||
<br />
|
||||
|
||||
While evaluating on different tasks, parameters need to be set as follows:
|
||||
###### TextVQA
|
||||
```bash
|
||||
--eval_textVQA
|
||||
--textVQA_image_dir ./downloads/TextVQA/train_images
|
||||
--textVQA_ann_path ./downloads/TextVQA/TextVQA_0.5.1_val.json
|
||||
```
|
||||
|
||||
###### DocVQA
|
||||
```bash
|
||||
--eval_docVQA
|
||||
--docVQA_image_dir ./downloads/DocVQA/spdocvqa_images
|
||||
--docVQA_ann_path ./downloads/DocVQA/val_v1.0_withQT.json
|
||||
```
|
||||
|
||||
###### DocVQATest
|
||||
```bash
|
||||
--eval_docVQATest
|
||||
--docVQATest_image_dir ./downloads/DocVQA/spdocvqa_images
|
||||
--docVQATest_ann_path ./downloads/DocVQA/test_v1.0.json
|
||||
```
|
||||
|
||||
<br />
|
||||
|
||||
For the DocVQATest task, in order to upload the inference results to the [official website](https://rrc.cvc.uab.es/?ch=17) for evaluation, run `shell/run_transform.sh` for format transformation after inference. `input_file_path` represents the path to the original output json, `output_file_path` represents the path to the transformed json:
|
||||
```bash
|
||||
chmod +x ./shell/run_transform.sh
|
||||
./shell/run_transform.sh
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -1,61 +1,57 @@
|
||||
# Evaluation
|
||||
|
||||
## opencompass
|
||||
## MiniCPM-o 2.6
|
||||
|
||||
### opencompass
|
||||
首先,进入 `vlmevalkit` 目录下,安装必要的依赖:
|
||||
```bash
|
||||
cd vlmevalkit
|
||||
pip install -r requirements.txt
|
||||
pip install --upgrade pip
|
||||
pip install -e .
|
||||
wget https://download.pytorch.org/whl/cu118/torch-2.2.0%2Bcu118-cp310-cp310-linux_x86_64.whl#sha256=4377e0a7fe8ff8ffc4f7c9c6130c1dcd3874050ae4fc28b7ff1d35234fbca423
|
||||
wget https://download.pytorch.org/whl/cu118/torchvision-0.17.0%2Bcu118-cp310-cp310-linux_x86_64.whl#sha256=2e63d62e09d9b48b407d3e1b30eb8ae4e3abad6968e8d33093b60d0657542428
|
||||
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
||||
pip install torch-2.2.0%2Bcu118-cp310-cp310-linux_x86_64.whl
|
||||
pip install torchvision-0.17.0%2Bcu118-cp310-cp310-linux_x86_64.whl
|
||||
pip install flash_attn-2.6.3+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
||||
rm *.whl
|
||||
```
|
||||
<br />
|
||||
|
||||
然后,运行 `script/run_inference.sh`,该脚本依次接收三个输入参数:`MODELNAME`, `DATALIST`, `MODE`。`MODELNAME` 为模型名称,`DATALIST` 为目标数据集,`MODE` 为评测模式。
|
||||
然后,运行 `scripts/run_inference.sh`,该脚本依次接收两个输入参数:`MODELNAME`, `DATALIST`。其中,`MODELNAME` 为模型名称,`DATALIST` 为目标数据集。
|
||||
```bash
|
||||
chmod +x ./script/run_inference.sh
|
||||
./script/run_inference.sh $MODELNAME $DATALIST $MODE
|
||||
chmod +x ./scripts/run_inference.sh
|
||||
./scripts/run_inference.sh $MODELNAME $DATALIST
|
||||
```
|
||||
<br />
|
||||
|
||||
`MODELNAME` 有三种选择,位于 `vlmeval/config.py` 中:
|
||||
`MODELNAME` 有五种选择,位于 `vlmeval/config.py` 中:
|
||||
```bash
|
||||
ungrouped = {
|
||||
'MiniCPM-V':partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
|
||||
'MiniCPM-V-2':partial(MiniCPM_V, model_path='openbmb/MiniCPM-V-2'),
|
||||
'MiniCPM-Llama3-V-2_5':partial(MiniCPM_Llama3_V, model_path='openbmb/MiniCPM-Llama3-V-2_5'),
|
||||
minicpm_series = {
|
||||
'MiniCPM-V': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
|
||||
'MiniCPM-V-2': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V-2'),
|
||||
'MiniCPM-Llama3-V-2_5': partial(MiniCPM_Llama3_V, model_path='openbmb/MiniCPM-Llama3-V-2_5'),
|
||||
'MiniCPM-V-2_6': partial(MiniCPM_V_2_6, model_path='openbmb/MiniCPM-V-2_6'),
|
||||
'MiniCPM-o-2_6': partial(MiniCPM_o_2_6, model_path='openbmb/MiniCPM-o-2_6'),
|
||||
}
|
||||
```
|
||||
<br />
|
||||
|
||||
可选的所有 `DATALIST` 位于 `vlmeval/utils/dataset_config.py` 中,评测单个数据集时,直接调用数据集名称,不加引号;评测多个数据集时,将不同数据集名称以空格隔开,两端加引号:
|
||||
可选的所有 `DATALIST` 位于 `vlmeval/utils/dataset_config.py` 中。一次评测多个数据集时,将不同数据集名称以空格隔开,两端加引号:
|
||||
```bash
|
||||
$DATALIST="POPE ScienceQA_TEST ChartQA_TEST"
|
||||
$DATALIST="MMMU_DEV_VAL MathVista_MINI MMVet MMBench_TEST_EN_V11 MMBench_TEST_CN_V11 MMStar HallusionBench AI2D_TEST"
|
||||
```
|
||||
<br />
|
||||
|
||||
直接对各 benchmark 进行评分时,设置 `MODE=all`。如果仅需要推理结果,则设置 `MODE=infer`
|
||||
为了复现出首页展示的表格中的各项结果(MME 到 RealWorldQA 之间的列),需要按照如下设置运行:
|
||||
当评测的 benchmark 需要 GPT 系列模型进行评分时,请在 `.env` 文件中预先指定 `OPENAI_API_BASE` 和 `OPENAI_API_KEY`。
|
||||
为了复现出首页展示的表格中 OpenCompass 对应的各项数据集以及 ChartQA 和 MME 上的结果(OCRBench 到 HallusionBench 之间的列),需要按照如下设置运行:
|
||||
```bash
|
||||
# 一次性运行 7 个数据集
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 "MME MMBench_TEST_EN MMBench_TEST_CN MMMU_DEV_VAL MathVista_MINI LLaVABench RealWorldQA" all
|
||||
|
||||
# 以下是单独运行 1 个数据集的指令
|
||||
# MME
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 MME all
|
||||
# MMBench_TEST_EN
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 MMBench_TEST_EN all
|
||||
# MMBench_TEST_CN
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 MMBench_TEST_CN all
|
||||
# MMMU_DEV_VAL
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 MMMU_DEV_VAL all
|
||||
# MathVista_MINI
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 MathVista_MINI all
|
||||
# LLaVABench
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 LLaVABench all
|
||||
# RealWorldQA
|
||||
./script/run_inference.sh MiniCPM-Llama3-V-2_5 RealWorldQA all
|
||||
# 请注意,对于 MME 的 perception 和 reasoning 集,我们采取了不同的 prompt 方式。评测 reasoning 子集时,需要使用 CoT,因此需要手动到 vlmeval/vlm/minicpm_v.py 中修改 use_cot 函数的判断条件
|
||||
./scripts/run_inference.sh MiniCPM-o-2_6 "MMMU_DEV_VAL MathVista_MINI MMVet MMBench_TEST_EN_V11 MMBench_TEST_CN_V11 MMStar HallusionBench AI2D_TEST OCRBench ChartQA_TEST MME"
|
||||
```
|
||||
<br />
|
||||
|
||||
## vqadataset
|
||||
### vqadataset
|
||||
首先,进入 `vqaeval` 目录下,安装必要的依赖,并创建 `downloads` 子目录,用于存储下载的数据集:
|
||||
```bash
|
||||
cd vqaeval
|
||||
@@ -112,7 +108,8 @@ chmod +x ./shell/run_inference.sh
|
||||
```
|
||||
<br />
|
||||
|
||||
可以传入的参数位于 `eval_utils/getargs.py` 中,各主要参数的含义如下:
|
||||
可以传入的参数位于 `eval_utils/getargs.py` 中,各主要参数的含义如下。
|
||||
对于 `MiniCPM-o-2_6`,需要将 `model_name`设置为 `minicpmo26`:
|
||||
```bash
|
||||
# 指定 TextVQA 评测所有图片和问题的路径
|
||||
--textVQA_image_dir
|
||||
@@ -172,4 +169,369 @@ chmod +x ./shell/run_inference.sh
|
||||
```bash
|
||||
chmod +x ./shell/run_transform.sh
|
||||
./shell/run_transform.sh
|
||||
```
|
||||
```
|
||||
|
||||
<br />
|
||||
|
||||
## MiniCPM-V 2.6
|
||||
|
||||
<details>
|
||||
<summary>展开</summary>
|
||||
|
||||
### opencompass
|
||||
首先,进入 `vlmevalkit` 目录下,安装必要的依赖:
|
||||
```bash
|
||||
cd vlmevalkit
|
||||
pip install --upgrade pip
|
||||
pip install -e .
|
||||
wget https://download.pytorch.org/whl/cu118/torch-2.2.0%2Bcu118-cp310-cp310-linux_x86_64.whl#sha256=4377e0a7fe8ff8ffc4f7c9c6130c1dcd3874050ae4fc28b7ff1d35234fbca423
|
||||
wget https://download.pytorch.org/whl/cu118/torchvision-0.17.0%2Bcu118-cp310-cp310-linux_x86_64.whl#sha256=2e63d62e09d9b48b407d3e1b30eb8ae4e3abad6968e8d33093b60d0657542428
|
||||
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
||||
pip install torch-2.2.0%2Bcu118-cp310-cp310-linux_x86_64.whl
|
||||
pip install torchvision-0.17.0%2Bcu118-cp310-cp310-linux_x86_64.whl
|
||||
pip install flash_attn-2.6.3+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
||||
rm *.whl
|
||||
```
|
||||
<br />
|
||||
|
||||
然后,运行 `scripts/run_inference.sh`,该脚本依次接收三个输入参数:`MODELNAME`, `DATALIST`, `MODE`。`MODELNAME` 为模型名称,`DATALIST` 为目标数据集,`MODE` 为评测模式。
|
||||
```bash
|
||||
chmod +x ./scripts/run_inference.sh
|
||||
./scripts/run_inference.sh $MODELNAME $DATALIST $MODE
|
||||
```
|
||||
<br />
|
||||
|
||||
`MODELNAME` 有四种选择,位于 `vlmeval/config.py` 中:
|
||||
```bash
|
||||
minicpm_series = {
|
||||
'MiniCPM-V': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
|
||||
'MiniCPM-V-2': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V-2'),
|
||||
'MiniCPM-Llama3-V-2_5': partial(MiniCPM_Llama3_V, model_path='openbmb/MiniCPM-Llama3-V-2_5'),
|
||||
'MiniCPM-V-2_6': partial(MiniCPM_V_2_6, model_path='openbmb/MiniCPM-V-2_6'),
|
||||
}
|
||||
```
|
||||
<br />
|
||||
|
||||
可选的所有 `DATALIST` 位于 `vlmeval/utils/dataset_config.py` 中。将不同数据集名称以空格隔开,两端加引号:
|
||||
```bash
|
||||
$DATALIST="MMMU_DEV_VAL MathVista_MINI MMVet MMBench_DEV_EN_V11 MMBench_DEV_CN_V11 MMStar HallusionBench AI2D_TEST"
|
||||
```
|
||||
<br />
|
||||
|
||||
直接对各 benchmark 进行评分时,设置 `MODE=all`。如果仅需要推理结果,则设置 `MODE=infer`。
|
||||
为了复现出首页展示的表格中的各项结果(MME 到 HallusionBench 之间的列),需要按照如下设置运行:
|
||||
```bash
|
||||
# without CoT
|
||||
./scripts/run_inference.sh MiniCPM-V-2_6 "MMMU_DEV_VAL MathVista_MINI MMVet MMBench_DEV_EN_V11 MMBench_DEV_CN_V11 MMStar HallusionBench AI2D_TEST" all
|
||||
./scripts/run_inference.sh MiniCPM-V-2_6 MME all
|
||||
# with CoT,运行 CoT 版本的 MME 时,需要改写 vlmeval/vlm/minicpm_v.py 中的 'use_cot' 函数,将 MME 添加到 return True 的分支中
|
||||
./scripts/run_inference.sh MiniCPM-V-2_6 "MMMU_DEV_VAL MMVet MMStar HallusionBench OCRBench" all
|
||||
./scripts/run_inference.sh MiniCPM-V-2_6 MME all
|
||||
```
|
||||
<br />
|
||||
|
||||
### vqadataset
|
||||
首先,进入 `vqaeval` 目录下,安装必要的依赖,并创建 `downloads` 子目录,用于存储下载的数据集:
|
||||
```bash
|
||||
cd vqaeval
|
||||
pip install -r requirements.txt
|
||||
mkdir downloads
|
||||
```
|
||||
<br />
|
||||
|
||||
然后,从下列各地址下载数据集并置于指定目录下:
|
||||
###### TextVQA
|
||||
```bash
|
||||
cd downloads
|
||||
mkdir TextVQA && cd TextVQA
|
||||
wget https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
|
||||
unzip train_val_images.zip && rm train_val_images.zip
|
||||
mv train_val_images/train_images . && rm -rf train_val_images
|
||||
wget https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
|
||||
cd ../..
|
||||
```
|
||||
|
||||
###### DocVQA / DocVQATest
|
||||
```bash
|
||||
cd downloads
|
||||
mkdir DocVQA && cd DocVQA && mkdir spdocvqa_images
|
||||
# 在 https://rrc.cvc.uab.es/?ch=17&com=downloads 下载 Task 1 - Single Page Document Visual Question Answering 下的 Images 和 Annotations
|
||||
# 将下载得到的 spdocvqa_images.tar.gz 以及 spdocvqa_qas.zip 置于 DocVQA 目录下
|
||||
tar -zxvf spdocvqa_images.tar.gz -C spdocvqa_images && rm spdocvqa_images.tar.gz
|
||||
unzip spdocvqa_qas.zip && rm spdocvqa_qas.zip
|
||||
cp spdocvqa_qas/val_v1.0_withQT.json . && cp spdocvqa_qas/test_v1.0.json . && rm -rf spdocvqa_qas
|
||||
cd ../..
|
||||
```
|
||||
<br />
|
||||
|
||||
`downloads` 目录应当按照下列结构组织:
|
||||
```bash
|
||||
downloads
|
||||
├── TextVQA
|
||||
│ ├── train_images
|
||||
│ │ ├── ...
|
||||
│ ├── TextVQA_0.5.1_val.json
|
||||
├── DocVQA
|
||||
│ ├── spdocvqa_images
|
||||
│ │ ├── ...
|
||||
│ ├── val_v1.0_withQT.json
|
||||
│ ├── test_v1.0.json
|
||||
```
|
||||
<br />
|
||||
|
||||
准备好相应的数据集之后,修改 `shell/run_inference.sh` 的参数,运行推理:
|
||||
|
||||
```bash
|
||||
chmod +x ./shell/run_inference.sh
|
||||
./shell/run_inference.sh
|
||||
```
|
||||
<br />
|
||||
|
||||
可以传入的参数位于 `eval_utils/getargs.py` 中,各主要参数的含义如下。
|
||||
对于 `MiniCPM-V-2_6`,需要将 `model_name`设置为 `minicpmv26`:
|
||||
```bash
|
||||
# 指定 TextVQA 评测所有图片和问题的路径
|
||||
--textVQA_image_dir
|
||||
--textVQA_ann_path
|
||||
# 指定 DocVQA 评测所有图片和问题的路径
|
||||
--docVQA_image_dir
|
||||
--docVQA_ann_path
|
||||
# 指定 DocVQATest 评测所有图片和问题的路径
|
||||
--docVQATest_image_dir
|
||||
--docVQATest_ann_path
|
||||
|
||||
# 决定是否评测某个任务,eval_all 设置为 True 表示所有任务都评测
|
||||
--eval_textVQA
|
||||
--eval_docVQA
|
||||
--eval_docVQATest
|
||||
--eval_all
|
||||
|
||||
# 模型名称、模型路径(从指定路径加载模型)
|
||||
--model_name
|
||||
--model_path
|
||||
# 从 checkpoint 加载模型
|
||||
--ckpt
|
||||
# 模型处理输入数据的方式,interleave 表示图文交错式,old 表示非交错式
|
||||
--generate_method
|
||||
# 推理时的批处理规模,建议推理时设置为 1
|
||||
--batchsize
|
||||
|
||||
# 输出内容保存的路径
|
||||
--answer_path
|
||||
```
|
||||
<br />
|
||||
|
||||
评测三个任务需要设置的参数如下:
|
||||
###### TextVQA
|
||||
```bash
|
||||
--eval_textVQA
|
||||
--textVQA_image_dir ./downloads/TextVQA/train_images
|
||||
--textVQA_ann_path ./downloads/TextVQA/TextVQA_0.5.1_val.json
|
||||
```
|
||||
|
||||
###### DocVQA
|
||||
```bash
|
||||
--eval_docVQA
|
||||
--docVQA_image_dir ./downloads/DocVQA/spdocvqa_images
|
||||
--docVQA_ann_path ./downloads/DocVQA/val_v1.0_withQT.json
|
||||
```
|
||||
|
||||
###### DocVQATest
|
||||
```bash
|
||||
--eval_docVQATest
|
||||
--docVQATest_image_dir ./downloads/DocVQA/spdocvqa_images
|
||||
--docVQATest_ann_path ./downloads/DocVQA/test_v1.0.json
|
||||
```
|
||||
<br />
|
||||
|
||||
对于 DocVQATest 任务,为了将推理结果上传到[官方网站](https://rrc.cvc.uab.es/?ch=17)进行评测,还需要运行 `shell/run_transform.sh` 进行格式转换。其中,`input_file_path` 对应原始输出的 json 的路径,`output_file_path` 为自定义的转换后的 json 的路径:
|
||||
```bash
|
||||
chmod +x ./shell/run_transform.sh
|
||||
./shell/run_transform.sh
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<br />
|
||||
|
||||
## MiniCPM-Llama3-V-2_5
|
||||
|
||||
<details>
|
||||
<summary>展开</summary>
|
||||
|
||||
### opencompass
|
||||
首先,进入 `vlmevalkit` 目录下,安装必要的依赖:
|
||||
```bash
|
||||
cd vlmevalkit
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
<br />
|
||||
|
||||
然后,运行 `scripts/run_inference.sh`,该脚本依次接收三个输入参数:`MODELNAME`, `DATALIST`, `MODE`。`MODELNAME` 为模型名称,`DATALIST` 为目标数据集,`MODE` 为评测模式。
|
||||
```bash
|
||||
chmod +x ./scripts/run_inference.sh
|
||||
./scripts/run_inference.sh $MODELNAME $DATALIST $MODE
|
||||
```
|
||||
<br />
|
||||
|
||||
`MODELNAME` 有三种选择,位于 `vlmeval/config.py` 中:
|
||||
```bash
|
||||
ungrouped = {
|
||||
'MiniCPM-V':partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
|
||||
'MiniCPM-V-2':partial(MiniCPM_V, model_path='openbmb/MiniCPM-V-2'),
|
||||
'MiniCPM-Llama3-V-2_5':partial(MiniCPM_Llama3_V, model_path='openbmb/MiniCPM-Llama3-V-2_5'),
|
||||
}
|
||||
```
|
||||
<br />
|
||||
|
||||
可选的所有 `DATALIST` 位于 `vlmeval/utils/dataset_config.py` 中,评测单个数据集时,直接调用数据集名称,不加引号;评测多个数据集时,将不同数据集名称以空格隔开,两端加引号:
|
||||
```bash
|
||||
$DATALIST="POPE ScienceQA_TEST ChartQA_TEST"
|
||||
```
|
||||
<br />
|
||||
|
||||
直接对各 benchmark 进行评分时,设置 `MODE=all`。如果仅需要推理结果,则设置 `MODE=infer`
|
||||
为了复现出首页展示的表格中的各项结果(MME 到 RealWorldQA 之间的列),需要按照如下设置运行:
|
||||
```bash
|
||||
# 一次性运行 7 个数据集
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 "MME MMBench_TEST_EN MMBench_TEST_CN MMMU_DEV_VAL MathVista_MINI LLaVABench RealWorldQA" all
|
||||
|
||||
# 以下是单独运行 1 个数据集的指令
|
||||
# MME
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MME all
|
||||
# MMBench_TEST_EN
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MMBench_TEST_EN all
|
||||
# MMBench_TEST_CN
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MMBench_TEST_CN all
|
||||
# MMMU_DEV_VAL
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MMMU_DEV_VAL all
|
||||
# MathVista_MINI
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MathVista_MINI all
|
||||
# LLaVABench
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 LLaVABench all
|
||||
# RealWorldQA
|
||||
./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 RealWorldQA all
|
||||
```
|
||||
<br />
|
||||
|
||||
### vqadataset
|
||||
首先,进入 `vqaeval` 目录下,安装必要的依赖,并创建 `downloads` 子目录,用于存储下载的数据集:
|
||||
```bash
|
||||
cd vqaeval
|
||||
pip install -r requirements.txt
|
||||
mkdir downloads
|
||||
```
|
||||
<br />
|
||||
|
||||
然后,从下列各地址下载数据集并置于指定目录下:
|
||||
###### TextVQA
|
||||
```bash
|
||||
cd downloads
|
||||
mkdir TextVQA && cd TextVQA
|
||||
wget https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
|
||||
unzip train_val_images.zip && rm train_val_images.zip
|
||||
mv train_val_images/train_images . && rm -rf train_val_images
|
||||
wget https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
|
||||
cd ../..
|
||||
```
|
||||
|
||||
###### DocVQA / DocVQATest
|
||||
```bash
|
||||
cd downloads
|
||||
mkdir DocVQA && cd DocVQA && mkdir spdocvqa_images
|
||||
# 在 https://rrc.cvc.uab.es/?ch=17&com=downloads 下载 Task 1 - Single Page Document Visual Question Answering 下的 Images 和 Annotations
|
||||
# 将下载得到的 spdocvqa_images.tar.gz 以及 spdocvqa_qas.zip 置于 DocVQA 目录下
|
||||
tar -zxvf spdocvqa_images.tar.gz -C spdocvqa_images && rm spdocvqa_images.tar.gz
|
||||
unzip spdocvqa_qas.zip && rm spdocvqa_qas.zip
|
||||
cp spdocvqa_qas/val_v1.0_withQT.json . && cp spdocvqa_qas/test_v1.0.json . && rm -rf spdocvqa_qas
|
||||
cd ../..
|
||||
```
|
||||
<br />
|
||||
|
||||
`downloads` 目录应当按照下列结构组织:
|
||||
```bash
|
||||
downloads
|
||||
├── TextVQA
|
||||
│ ├── train_images
|
||||
│ │ ├── ...
|
||||
│ ├── TextVQA_0.5.1_val.json
|
||||
├── DocVQA
|
||||
│ ├── spdocvqa_images
|
||||
│ │ ├── ...
|
||||
│ ├── val_v1.0_withQT.json
|
||||
│ ├── test_v1.0.json
|
||||
```
|
||||
<br />
|
||||
|
||||
准备好相应的数据集之后,修改 `shell/run_inference.sh` 的参数,运行推理:
|
||||
|
||||
```bash
|
||||
chmod +x ./shell/run_inference.sh
|
||||
./shell/run_inference.sh
|
||||
```
|
||||
<br />
|
||||
|
||||
可以传入的参数位于 `eval_utils/getargs.py` 中,各主要参数的含义如下。
|
||||
对于 `MiniCPM-Llama3-V-2_5`,需要将 `model_name` 设置为 `minicpmv`:
|
||||
```bash
|
||||
# 指定 TextVQA 评测所有图片和问题的路径
|
||||
--textVQA_image_dir
|
||||
--textVQA_ann_path
|
||||
# 指定 DocVQA 评测所有图片和问题的路径
|
||||
--docVQA_image_dir
|
||||
--docVQA_ann_path
|
||||
# 指定 DocVQATest 评测所有图片和问题的路径
|
||||
--docVQATest_image_dir
|
||||
--docVQATest_ann_path
|
||||
|
||||
# 决定是否评测某个任务,eval_all 设置为 True 表示所有任务都评测
|
||||
--eval_textVQA
|
||||
--eval_docVQA
|
||||
--eval_docVQATest
|
||||
--eval_all
|
||||
|
||||
# 模型名称、模型路径(从指定路径加载模型)
|
||||
--model_name
|
||||
--model_path
|
||||
# 从 checkpoint 加载模型
|
||||
--ckpt
|
||||
# 模型处理输入数据的方式,interleave 表示图文交错式,old 表示非交错式
|
||||
--generate_method
|
||||
# 推理时的批处理规模,建议推理时设置为 1
|
||||
--batchsize
|
||||
|
||||
# 输出内容保存的路径
|
||||
--answer_path
|
||||
```
|
||||
<br />
|
||||
|
||||
评测三个任务需要设置的参数如下:
|
||||
###### TextVQA
|
||||
```bash
|
||||
--eval_textVQA
|
||||
--textVQA_image_dir ./downloads/TextVQA/train_images
|
||||
--textVQA_ann_path ./downloads/TextVQA/TextVQA_0.5.1_val.json
|
||||
```
|
||||
|
||||
###### DocVQA
|
||||
```bash
|
||||
--eval_docVQA
|
||||
--docVQA_image_dir ./downloads/DocVQA/spdocvqa_images
|
||||
--docVQA_ann_path ./downloads/DocVQA/val_v1.0_withQT.json
|
||||
```
|
||||
|
||||
###### DocVQATest
|
||||
```bash
|
||||
--eval_docVQATest
|
||||
--docVQATest_image_dir ./downloads/DocVQA/spdocvqa_images
|
||||
--docVQATest_ann_path ./downloads/DocVQA/test_v1.0.json
|
||||
```
|
||||
<br />
|
||||
|
||||
对于 DocVQATest 任务,为了将推理结果上传到[官方网站](https://rrc.cvc.uab.es/?ch=17)进行评测,还需要运行 `shell/run_transform.sh` 进行格式转换。其中,`input_file_path` 对应原始输出的 json 的路径,`output_file_path` 为自定义的转换后的 json 的路径:
|
||||
```bash
|
||||
chmod +x ./shell/run_transform.sh
|
||||
./shell/run_transform.sh
|
||||
```
|
||||
|
||||
</details>
|
||||
28
eval_mm/vlmevalkit/.env
Normal file
@@ -0,0 +1,28 @@
|
||||
# .env file; place it under $VLMEvalKit
|
||||
# API keys for proprietary VLMs
|
||||
# QwenVL APIs
|
||||
DASHSCOPE_API_KEY=
|
||||
# Gemini w. Google Cloud Backends
|
||||
GOOGLE_API_KEY=
|
||||
# OpenAI API
|
||||
OPENAI_API_KEY=
|
||||
OPENAI_API_BASE=
|
||||
# StepAI API
|
||||
STEPAI_API_KEY=
|
||||
# REKA API
|
||||
REKA_API_KEY=
|
||||
# GLMV API
|
||||
GLMV_API_KEY=
|
||||
# CongRong API
|
||||
CW_API_BASE=
|
||||
CW_API_KEY=
|
||||
# SenseChat-V API
|
||||
SENSECHAT_AK=
|
||||
SENSECHAT_SK=
|
||||
# Hunyuan-Vision API
|
||||
HUNYUAN_SECRET_KEY=
|
||||
HUNYUAN_SECRET_ID=
|
||||
# LMDeploy API
|
||||
LMDEPLOY_API_BASE=
|
||||
# You can set an evaluation-time proxy; API calls made during the evaluation stage will go through it
|
||||
EVAL_PROXY=
|
||||
@@ -1,33 +1,30 @@
|
||||
einops
|
||||
gradio==4.15.0
|
||||
decord; platform_machine != 'arm64'
|
||||
eva-decord; platform_machine == 'arm64'
|
||||
gradio
|
||||
huggingface_hub
|
||||
imageio
|
||||
matplotlib
|
||||
numpy>=1.23.4
|
||||
numpy
|
||||
omegaconf
|
||||
openai==1.3.5
|
||||
openai
|
||||
opencv-python>=4.4.0.46
|
||||
openpyxl
|
||||
pandas>=1.5.3
|
||||
pandas
|
||||
pillow
|
||||
portalocker
|
||||
protobuf
|
||||
pycocoevalcap
|
||||
python-dotenv
|
||||
requests
|
||||
rich
|
||||
seaborn
|
||||
sentencepiece
|
||||
setuptools
|
||||
sty
|
||||
tabulate
|
||||
tiktoken
|
||||
timeout-decorator
|
||||
torch
|
||||
tqdm
|
||||
typing_extensions==4.7.1
|
||||
transformers
|
||||
typing_extensions
|
||||
validators
|
||||
visual_genome
|
||||
xlsxwriter
|
||||
Pillow==10.1.0
|
||||
sentencepiece==0.1.99
|
||||
transformers==4.40.0
|
||||
torch==1.13.1
|
||||
torchvision
|
||||
|
||||
11
eval_mm/vlmevalkit/requirements/docs.txt
Normal file
@@ -0,0 +1,11 @@
|
||||
docutils==0.18.1
|
||||
modelindex
|
||||
myst-parser
|
||||
-e git+https://github.com/open-compass/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
|
||||
sphinx==6.1.3
|
||||
sphinx-copybutton
|
||||
sphinx-design
|
||||
sphinx-notfound-page
|
||||
sphinx-tabs
|
||||
sphinxcontrib-jquery
|
||||
tabulate
|
||||
@@ -1,147 +1,422 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from vlmeval.smp import *
|
||||
from vlmeval.evaluate import *
|
||||
from vlmeval.inference import infer_data_job
|
||||
|
||||
from vlmeval.config import supported_VLM
|
||||
from vlmeval.utils import dataset_URLs, DATASET_TYPE, abbr2full, MMMU_result_transfer
|
||||
from vlmeval.dataset.video_dataset_config import supported_video_datasets
|
||||
from vlmeval.dataset import build_dataset
|
||||
from vlmeval.inference import infer_data_job
|
||||
from vlmeval.inference_video import infer_data_job_video
|
||||
from vlmeval.inference_mt import infer_data_job_mt
|
||||
from vlmeval.smp import *
|
||||
from vlmeval.utils.result_transfer import MMMU_result_transfer, MMTBench_result_transfer
|
||||
|
||||
|
||||
def build_model_from_config(cfg, model_name):
|
||||
import vlmeval.api
|
||||
import vlmeval.vlm
|
||||
config = cp.deepcopy(cfg[model_name])
|
||||
if config == {}:
|
||||
return supported_VLM[model_name]()
|
||||
assert 'class' in config
|
||||
cls_name = config.pop('class')
|
||||
if hasattr(vlmeval.api, cls_name):
|
||||
return getattr(vlmeval.api, cls_name)(**config)
|
||||
elif hasattr(vlmeval.vlm, cls_name):
|
||||
return getattr(vlmeval.vlm, cls_name)(**config)
|
||||
else:
|
||||
raise ValueError(f'Class {cls_name} is not supported in `vlmeval.api` or `vlmeval.vlm`')
|
||||
|
||||
|
||||
def build_dataset_from_config(cfg, dataset_name):
|
||||
import vlmeval.dataset
|
||||
import inspect
|
||||
config = cp.deepcopy(cfg[dataset_name])
|
||||
if config == {}:
|
||||
return supported_video_datasets[dataset_name]()
|
||||
assert 'class' in config
|
||||
cls_name = config.pop('class')
|
||||
if hasattr(vlmeval.dataset, cls_name):
|
||||
cls = getattr(vlmeval.dataset, cls_name)
|
||||
sig = inspect.signature(cls.__init__)
|
||||
valid_params = {k: v for k, v in config.items() if k in sig.parameters}
|
||||
if valid_params.get('fps', 0) > 0 and valid_params.get('nframe', 0) > 0:
|
||||
raise ValueError('fps and nframe should not be set at the same time')
|
||||
if valid_params.get('fps', 0) <= 0 and valid_params.get('nframe', 0) <= 0:
|
||||
raise ValueError('fps and nframe should be set at least one valid value')
|
||||
return cls(**valid_params)
|
||||
else:
|
||||
raise ValueError(f'Class {cls_name} is not supported in `vlmeval.dataset`')
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--data', type=str, nargs='+', required=True)
|
||||
parser.add_argument('--model', type=str, nargs='+', required=True)
|
||||
parser.add_argument('--work-dir', type=str, default='.', help='select the output directory')
|
||||
help_msg = """\
|
||||
You can launch the evaluation by setting either --data and --model or --config.
|
||||
|
||||
--data and --model:
|
||||
Each Arg should be a list of strings, specifying the names of datasets and models.
|
||||
To find all supported model names, please refer to the `vlmeval/config.py` or check the output of the command \
|
||||
`vlmutil mlist all` in the terminal (you should first have vlmeval installed).
|
||||
To find all supported dataset names, please refer to the `vlmeval/dataset/__init__.py` file. The python script \
|
||||
to print all supported dataset names is as follows:
|
||||
```python
|
||||
from vlmeval.dataset import SUPPORTED_DATASETS
|
||||
print(SUPPORTED_DATASETS)
|
||||
```
|
||||
or you can check the output of the command `vlmutil dlist all` in the terminal.
|
||||
To find all supported video dataset default settings, please refer to the \
|
||||
`vlmeval/dataset/video_dataset_config.py` file.
|
||||
|
||||
--config:
|
||||
Launch the evaluation by specifying the path to the config json file. Sample Json Content:
|
||||
```json
|
||||
{
|
||||
"model": {
|
||||
"GPT4o_20240806_T00_HIGH": {
|
||||
"class": "GPT4V",
|
||||
"model": "gpt-4o-2024-08-06",
|
||||
"temperature": 0,
|
||||
"img_detail": "high"
|
||||
},
|
||||
"GPT4o_20240806_T10_Low": {
|
||||
"class": "GPT4V",
|
||||
"model": "gpt-4o-2024-08-06",
|
||||
"temperature": 1.0,
|
||||
"img_detail": "low"
|
||||
},
|
||||
"GPT4o_20241120": {}
|
||||
},
|
||||
"data": {
|
||||
"MME-RealWorld-Lite": {
|
||||
"class": "MMERealWorld",
|
||||
"dataset": "MME-RealWorld-Lite"
|
||||
},
|
||||
"MMBench_DEV_EN_V11": {
|
||||
"class": "ImageMCQDataset",
|
||||
"dataset": "MMBench_DEV_EN_V11"
|
||||
},
|
||||
"MMBench_Video_8frame_nopack": {},
|
||||
"Video-MME_16frame_subs": {
|
||||
"class": "VideoMME",
|
||||
"dataset": "Video-MME",
|
||||
"nframe": 16,
|
||||
"use_subtitle": true,
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
Currently, only `model` and `data` are supported fields. The content of each field is a dictionary.
|
||||
For `model`, the key is the name of the model, and the value is a dictionary containing the following keys:
|
||||
- `class`: The class name of the model, which should be a class in `vlmeval.vlm` or `vlmeval.api`.
|
||||
- Other keys are specific to the model, please refer to the corresponding class.
|
||||
- Tip: The defined model in the `supported_VLM` of `vlmeval/config.py` can be used as a shortcut.
|
||||
For `data`, the key is the name of the dataset (should be the same as the `dataset` field in most cases, \
|
||||
except for video datasets), and the value is a dictionary containing the following keys:
|
||||
- `class`: The class name of the dataset, which should be a class in `vlmeval.dataset`.
|
||||
- `dataset`: The name of the dataset, which should be a string that is accepted by the `dataset` argument of the \
|
||||
corresponding class.
|
||||
- Other keys are specific to the dataset, please refer to the corresponding class.
|
||||
- Tip: The defined dataset in the `supported_video_datasets` of `vlmeval/dataset/video_dataset_config.py` \
|
||||
can be used as a shortcut.
|
||||
|
||||
The keys in the `model` and `data` fields will be used for naming the prediction files and evaluation results.
|
||||
When launching with `--config`, args for API VLMs, such as `--retry`, `--verbose`, will be ignored.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description=help_msg, formatter_class=argparse.RawTextHelpFormatter)
|
||||
# Essential Args, Setting the Names of Datasets and Models
|
||||
parser.add_argument('--data', type=str, nargs='+', help='Names of Datasets')
|
||||
parser.add_argument('--model', type=str, nargs='+', help='Names of Models')
|
||||
parser.add_argument('--config', type=str, help='Path to the Config Json File')
|
||||
# Work Dir
|
||||
parser.add_argument('--work-dir', type=str, default='./outputs', help='select the output directory')
|
||||
# Infer + Eval or Infer Only
|
||||
parser.add_argument('--mode', type=str, default='all', choices=['all', 'infer'])
|
||||
parser.add_argument('--nproc', type=int, default=4, help='Parallel API calling')
|
||||
# API Kwargs, Apply to API VLMs and Judge API LLMs
|
||||
parser.add_argument('--api_nproc', type=int, default=4, help='Parallel API calling')
|
||||
parser.add_argument('--retry', type=int, default=None, help='retry numbers for API VLMs')
|
||||
# Explicitly Set the Judge Model
|
||||
parser.add_argument('--judge', type=str, default=None)
|
||||
parser.add_argument('--ignore', action='store_true', help='Ignore failed indices. ')
|
||||
# Logging Utils
|
||||
parser.add_argument('--verbose', action='store_true')
|
||||
parser.add_argument('--rerun', action='store_true')
|
||||
# Configuration for Resume
|
||||
# Ignore: will not rerun failed VLM inference
|
||||
parser.add_argument('--ignore', action='store_true', help='Ignore failed indices. ')
|
||||
# Reuse: will reuse the existing prediction files
|
||||
parser.add_argument('--reuse', action='store_true')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
logger = get_logger('RUN')
|
||||
|
||||
rank, world_size = get_rank_and_world_size()
|
||||
args = parse_args()
|
||||
assert len(args.data), '--data should be a list of data files'
|
||||
use_config, cfg = False, None
|
||||
if args.config is not None:
|
||||
assert args.data is None and args.model is None, '--data and --model should not be set when using --config'
|
||||
use_config, cfg = True, load(args.config)
|
||||
args.model = list(cfg['model'].keys())
|
||||
args.data = list(cfg['data'].keys())
|
||||
else:
|
||||
assert len(args.data), '--data should be a list of data files'
|
||||
|
||||
if args.retry is not None:
|
||||
if rank == 0:
|
||||
if not args.reuse:
|
||||
logger.warning('--reuse is not set, will not reuse previous (before one day) temporary files')
|
||||
else:
|
||||
logger.warning('--reuse is set, will reuse the latest prediction & temporary pickle files')
|
||||
|
||||
if 'MMEVAL_ROOT' in os.environ:
|
||||
args.work_dir = os.environ['MMEVAL_ROOT']
|
||||
|
||||
if not use_config:
|
||||
for k, v in supported_VLM.items():
|
||||
if hasattr(v, 'keywords') and 'retry' in v.keywords:
|
||||
if hasattr(v, 'keywords') and 'retry' in v.keywords and args.retry is not None:
|
||||
v.keywords['retry'] = args.retry
|
||||
supported_VLM[k] = v
|
||||
if hasattr(v, 'keywords') and 'verbose' in v.keywords:
|
||||
if hasattr(v, 'keywords') and 'verbose' in v.keywords and args.verbose is not None:
|
||||
v.keywords['verbose'] = args.verbose
|
||||
supported_VLM[k] = v
|
||||
|
||||
rank, world_size = get_rank_and_world_size()
|
||||
if world_size > 1:
|
||||
local_rank = os.environ.get('LOCAL_RANK', 0)
|
||||
torch.cuda.set_device(int(local_rank))
|
||||
dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=10800))
|
||||
dist.init_process_group(
|
||||
backend='nccl',
|
||||
timeout=datetime.timedelta(seconds=int(os.environ.get('DIST_TIMEOUT', 3600)))
|
||||
)
|
||||
|
||||
for _, model_name in enumerate(args.model):
|
||||
model = None
|
||||
date, commit_id = timestr('day'), githash(digits=8)
|
||||
eval_id = f"T{date}_G{commit_id}"
|
||||
|
||||
pred_root = osp.join(args.work_dir, model_name)
|
||||
os.makedirs(pred_root, exist_ok=True)
|
||||
pred_root = osp.join(args.work_dir, model_name, eval_id)
|
||||
pred_root_meta = osp.join(args.work_dir, model_name)
|
||||
os.makedirs(pred_root_meta, exist_ok=True)
|
||||
|
||||
prev_pred_roots = ls(osp.join(args.work_dir, model_name), mode='dir')
|
||||
if len(prev_pred_roots) and args.reuse:
|
||||
prev_pred_roots.sort()
|
||||
|
||||
if not osp.exists(pred_root):
|
||||
os.makedirs(pred_root, exist_ok=True)
|
||||
|
||||
if use_config:
|
||||
model = build_model_from_config(cfg['model'], model_name)
|
||||
|
||||
for _, dataset_name in enumerate(args.data):
|
||||
custom_flag = False
|
||||
try:
|
||||
result_file_base = f'{model_name}_{dataset_name}.xlsx'
|
||||
|
||||
if dataset_name not in dataset_URLs:
|
||||
dataset_name = abbr2full(dataset_name)
|
||||
|
||||
if dataset_name not in dataset_URLs:
|
||||
logger.warning(f'Dataset {dataset_name} is not officially supported. ')
|
||||
file_path = osp.join(LMUDataRoot(), f'{dataset_name}.tsv')
|
||||
if not osp.exists(file_path):
|
||||
logger.error(f'Cannot find the local dataset {dataset_name}. ')
|
||||
continue
|
||||
if use_config:
|
||||
if world_size > 1:
|
||||
if rank == 0:
|
||||
dataset = build_dataset_from_config(cfg['data'], dataset_name)
|
||||
dist.barrier()
|
||||
dataset = build_dataset_from_config(cfg['data'], dataset_name)
|
||||
if dataset is None:
|
||||
logger.error(f'Dataset {dataset_name} is not valid, will be skipped. ')
|
||||
continue
|
||||
else:
|
||||
custom_flag = True
|
||||
dataset_kwargs = {}
|
||||
if dataset_name in ['MMLongBench_DOC', 'DUDE', 'DUDE_MINI', 'SLIDEVQA', 'SLIDEVQA_MINI']:
|
||||
dataset_kwargs['model'] = model_name
|
||||
|
||||
result_file = f'{pred_root}/{model_name}_{dataset_name}.xlsx'
|
||||
if osp.exists(result_file) and args.rerun:
|
||||
os.system(f'rm {pred_root}/{model_name}_{dataset_name}_*')
|
||||
# If distributed, first build the dataset on the main process for doing preparation works
|
||||
if world_size > 1:
|
||||
if rank == 0:
|
||||
dataset = build_dataset(dataset_name, **dataset_kwargs)
|
||||
dist.barrier()
|
||||
|
||||
if model is None:
|
||||
model = model_name # which is only a name
|
||||
dataset = build_dataset(dataset_name, **dataset_kwargs)
|
||||
if dataset is None:
|
||||
logger.error(f'Dataset {dataset_name} is not valid, will be skipped. ')
|
||||
continue
|
||||
|
||||
model = infer_data_job(
|
||||
model,
|
||||
work_dir=pred_root,
|
||||
model_name=model_name,
|
||||
dataset_name=dataset_name,
|
||||
verbose=args.verbose,
|
||||
api_nproc=args.nproc,
|
||||
ignore_failed=args.ignore)
|
||||
# Handling Multi-Turn Dataset
|
||||
if dataset.TYPE == 'MT':
|
||||
result_file_base = result_file_base.replace('.xlsx', '.tsv')
|
||||
|
||||
if rank == 0:
|
||||
if dataset_name in ['MMMU_TEST']:
|
||||
result_json = MMMU_result_transfer(result_file)
|
||||
logger.info(f'Transfer MMMU_TEST result to json for official evaluation, json file saved in {result_json}') # noqa: E501
|
||||
continue
|
||||
result_file = osp.join(pred_root, result_file_base)
|
||||
|
||||
if dataset_name in [
|
||||
'MMBench_TEST_CN', 'MMBench_TEST_EN', 'MMBench', 'MMBench_CN'
|
||||
'MMBench_TEST_CN_V11', 'MMBench_TEST_EN_V11', 'MMBench_V11', 'MMBench_CN_V11'
|
||||
]:
|
||||
if not MMBenchOfficialServer(dataset_name):
|
||||
logger.error(
|
||||
f'Can not evaluate {dataset_name} on non-official servers, '
|
||||
'will skip the evaluation. '
|
||||
)
|
||||
continue
|
||||
# Reuse the previous prediction file if exists
|
||||
if rank == 0 and len(prev_pred_roots):
|
||||
prev_result_file = None
|
||||
prev_pkl_file_list = []
|
||||
for root in prev_pred_roots[::-1]:
|
||||
if osp.exists(osp.join(root, result_file_base)):
|
||||
prev_result_file = osp.join(root, result_file_base)
|
||||
break
|
||||
elif commit_id in root and len(ls(root)) and root != pred_root:
|
||||
temp_files = ls(root, match=[dataset_name, '.pkl'])
|
||||
if len(temp_files):
|
||||
prev_pkl_file_list.extend(temp_files)
|
||||
break
|
||||
if not args.reuse:
|
||||
prev_result_file = None
|
||||
prev_pkl_file_list = []
|
||||
if prev_result_file is not None:
|
||||
logger.warning(
|
||||
f'--reuse is set, will reuse the prediction file {prev_result_file}.')
|
||||
if prev_result_file != result_file:
|
||||
shutil.copy(prev_result_file, result_file)
|
||||
elif len(prev_pkl_file_list):
|
||||
for fname in prev_pkl_file_list:
|
||||
target_path = osp.join(pred_root, osp.basename(fname))
|
||||
if not osp.exists(target_path):
|
||||
shutil.copy(fname, target_path)
|
||||
logger.info(f'--reuse is set, will reuse the prediction pickle file {fname}.')
|
||||
else:
|
||||
logger.warning(f'File already exists: {target_path}')
|
||||
|
||||
judge_kwargs = {
|
||||
'nproc': args.nproc,
|
||||
'verbose': args.verbose,
|
||||
}
|
||||
if args.retry is not None:
|
||||
judge_kwargs['retry'] = args.retry
|
||||
if args.judge is not None:
|
||||
judge_kwargs['model'] = args.judge
|
||||
else:
|
||||
if DATASET_TYPE(dataset_name) in ['multi-choice', 'Y/N']:
|
||||
judge_kwargs['model'] = 'chatgpt-0613'
|
||||
elif listinstr(['MMVet', 'MathVista', 'LLaVABench'], dataset_name):
|
||||
judge_kwargs['model'] = 'gpt-4-turbo'
|
||||
if 'OPENAI_API_KEY_JUDGE' in os.environ and len(os.environ['OPENAI_API_KEY_JUDGE']):
|
||||
judge_kwargs['key'] = os.environ['OPENAI_API_KEY_JUDGE']
|
||||
if 'OPENAI_API_BASE_JUDGE' in os.environ and len(os.environ['OPENAI_API_BASE_JUDGE']):
|
||||
judge_kwargs['api_base'] = os.environ['OPENAI_API_BASE_JUDGE']
|
||||
if world_size > 1:
|
||||
dist.barrier()
|
||||
|
||||
if rank == 0 and args.mode == 'all':
|
||||
if DATASET_TYPE(dataset_name) == 'multi-choice':
|
||||
dataset_name = 'default' if custom_flag else dataset_name
|
||||
multiple_choice_eval(
|
||||
result_file,
|
||||
dataset=dataset_name,
|
||||
**judge_kwargs)
|
||||
elif DATASET_TYPE(dataset_name) == 'Y/N':
|
||||
YOrN_eval(
|
||||
result_file,
|
||||
dataset=dataset_name,
|
||||
**judge_kwargs)
|
||||
elif DATASET_TYPE(dataset_name) == 'Caption':
|
||||
COCO_eval(result_file)
|
||||
elif dataset_name == 'MMVet':
|
||||
MMVet_eval(result_file, **judge_kwargs)
|
||||
elif dataset_name == 'OCRBench':
|
||||
OCRBench_eval(result_file)
|
||||
elif listinstr(['OCRVQA', 'TextVQA', 'ChartQA', 'DocVQA', 'InfoVQA'], dataset_name):
|
||||
VQAEval(result_file, dataset_name)
|
||||
elif listinstr(['MathVista'], dataset_name):
|
||||
MathVista_eval(result_file, **judge_kwargs)
|
||||
elif listinstr(['LLaVABench'], dataset_name):
|
||||
LLaVABench_eval(result_file, **judge_kwargs)
|
||||
if model is None:
|
||||
model = model_name # which is only a name
|
||||
|
||||
# Perform the Inference
|
||||
if dataset.MODALITY == 'VIDEO':
|
||||
model = infer_data_job_video(
|
||||
model,
|
||||
work_dir=pred_root,
|
||||
model_name=model_name,
|
||||
dataset=dataset,
|
||||
result_file_name=result_file_base,
|
||||
verbose=args.verbose,
|
||||
api_nproc=args.api_nproc)
|
||||
elif dataset.TYPE == 'MT':
|
||||
model = infer_data_job_mt(
|
||||
model,
|
||||
work_dir=pred_root,
|
||||
model_name=model_name,
|
||||
dataset=dataset,
|
||||
verbose=args.verbose,
|
||||
api_nproc=args.api_nproc,
|
||||
ignore_failed=args.ignore)
|
||||
else:
|
||||
logger.error(f'Dataset {dataset_name} is not handled by evaluator, will be skipped. ')
|
||||
model = infer_data_job(
|
||||
model,
|
||||
work_dir=pred_root,
|
||||
model_name=model_name,
|
||||
dataset=dataset,
|
||||
verbose=args.verbose,
|
||||
api_nproc=args.api_nproc,
|
||||
ignore_failed=args.ignore)
|
||||
|
||||
# Set the judge kwargs first before evaluation or dumping
|
||||
|
||||
judge_kwargs = {
|
||||
'nproc': args.api_nproc,
|
||||
'verbose': args.verbose,
|
||||
'retry': args.retry if args.retry is not None else 3
|
||||
}
|
||||
|
||||
if args.retry is not None:
|
||||
judge_kwargs['retry'] = args.retry
|
||||
if args.judge is not None:
|
||||
judge_kwargs['model'] = args.judge
|
||||
else:
|
||||
if dataset.TYPE in ['MCQ', 'Y/N']:
|
||||
judge_kwargs['model'] = 'chatgpt-0125'
|
||||
elif listinstr(['MMVet', 'LLaVABench', 'MMBench-Video'], dataset_name):
|
||||
judge_kwargs['model'] = 'gpt-4-turbo'
|
||||
elif listinstr(['MathVista', 'MathVerse', 'MathVision', 'DynaMath', 'VL-RewardBench', 'WeMath', 'LogicVista'], dataset_name): # noqa: E501
|
||||
judge_kwargs['model'] = 'gpt-4o-mini'
|
||||
elif listinstr(['MMLongBench', 'MMDU', 'DUDE', 'SLIDEVQA', 'MIA-Bench', 'WildVision'], dataset_name): # noqa: E501
|
||||
judge_kwargs['model'] = 'gpt-4o'
|
||||
|
||||
if rank == 0:
|
||||
logger.info(judge_kwargs)
|
||||
|
||||
if world_size > 1:
|
||||
dist.barrier()
|
||||
|
||||
# Only Rank 0 handles the evaluation part
|
||||
if rank == 0:
|
||||
# Prepare Submission Files for MMMU_TEST AND MMT-Bench_ALL
|
||||
if dataset_name in ['MMMU_TEST']:
|
||||
result_json = MMMU_result_transfer(result_file)
|
||||
logger.info(f'Transfer MMMU_TEST result to json for official evaluation, '
|
||||
f'json file saved in {result_json}')
|
||||
continue
|
||||
elif 'MMT-Bench_ALL' in dataset_name:
|
||||
submission_file = MMTBench_result_transfer(result_file, **judge_kwargs)
|
||||
logger.info(f'Extract options from prediction of MMT-Bench FULL split for official evaluation '
|
||||
f'(https://eval.ai/web/challenges/challenge-page/2328/overview), '
|
||||
f'submission file saved in {submission_file}')
|
||||
continue
|
||||
|
||||
# Skip the evaluation part if only infer
|
||||
if args.mode == 'infer':
|
||||
continue
|
||||
|
||||
# Skip the evaluation part if the dataset evaluation is not supported or annotations are missing
|
||||
if 'MLLMGuard_DS' in dataset_name:
|
||||
logger.info('The evaluation of MLLMGuard_DS is not supported yet. ')
|
||||
continue
|
||||
elif 'AesBench_TEST' == dataset_name:
|
||||
logger.info(f'The results are saved in {result_file}. '
|
||||
f'Please send it to the AesBench Team via huangyipo@hotmail.com.')
|
||||
continue
|
||||
elif dataset_name in ['DocVQA_TEST', 'InfoVQA_TEST', 'Q-Bench1_TEST', 'A-Bench_TEST']:
|
||||
logger.info(f'{dataset_name} is a test split without ground-truth. '
|
||||
'Thus only the inference part is supported for those datasets. ')
|
||||
continue
|
||||
elif dataset_name in [
|
||||
'MMBench_TEST_CN', 'MMBench_TEST_EN', 'MMBench', 'MMBench_CN',
|
||||
'MMBench_TEST_CN_V11', 'MMBench_TEST_EN_V11', 'MMBench_V11', 'MMBench_CN_V11'
|
||||
] and not MMBenchOfficialServer(dataset_name):
|
||||
logger.error(
|
||||
f'Can not evaluate {dataset_name} on non-official servers, will skip the evaluation.')
|
||||
continue
|
||||
|
||||
# Setup the proxy for the evaluation
|
||||
eval_proxy = os.environ.get('EVAL_PROXY', None)
|
||||
old_proxy = os.environ.get('HTTP_PROXY', '')
|
||||
if eval_proxy is not None:
|
||||
proxy_set(eval_proxy)
|
||||
|
||||
# Perform the Evaluation
|
||||
eval_results = dataset.evaluate(result_file, **judge_kwargs)
|
||||
# Display Evaluation Results in Terminal
|
||||
if eval_results is not None:
|
||||
assert isinstance(eval_results, dict) or isinstance(eval_results, pd.DataFrame)
|
||||
logger.info(f'The evaluation of model {model_name} x dataset {dataset_name} has finished! ')
|
||||
logger.info('Evaluation Results:')
|
||||
if isinstance(eval_results, dict):
|
||||
logger.info('\n' + json.dumps(eval_results, indent=4))
|
||||
elif isinstance(eval_results, pd.DataFrame):
|
||||
if len(eval_results) < len(eval_results.columns):
|
||||
eval_results = eval_results.T
|
||||
logger.info('\n' + tabulate(eval_results))
|
||||
|
||||
# Restore the proxy
|
||||
if eval_proxy is not None:
|
||||
proxy_set(old_proxy)
|
||||
|
||||
# Create the symbolic links for the prediction files
|
||||
files = os.listdir(pred_root)
|
||||
files = [x for x in files if (f'{model_name}_{dataset_name}' in x or "status.json" in x)]
|
||||
for f in files:
|
||||
cwd = os.getcwd()
|
||||
file_addr = osp.join(cwd, pred_root, f)
|
||||
link_addr = osp.join(cwd, pred_root_meta, f)
|
||||
if osp.exists(link_addr) or osp.islink(link_addr):
|
||||
os.remove(link_addr)
|
||||
os.symlink(file_addr, link_addr)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f'Model {model_name} x Dataset {dataset_name} combination failed: {e}, '
|
||||
'skipping this combination.')
|
||||
continue
|
||||
|
||||
if world_size > 1:
|
||||
dist.barrier()
|
||||
|
||||
if world_size > 1:
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
export PATH=/usr/local/cuda/bin:$PATH
|
||||
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
export OMP_NUM_THREADS=1
|
||||
export timestamp=`date +"%Y%m%d%H%M%S"`
|
||||
export OLD_VERSION='False'
|
||||
export PYTHONPATH=$(dirname $SELF_DIR):$PYTHONPATH
|
||||
|
||||
# gpu consumed
|
||||
# fp16 17-18G
|
||||
# int4 7-8G
|
||||
|
||||
# model to be used
|
||||
# Example: MODELNAME=MiniCPM-Llama3-V-2_5
|
||||
MODELNAME=$1
|
||||
# datasets to be tested
|
||||
# Example: DATALIST="POPE ScienceQA_TEST ChartQA_TEST"
|
||||
DATALIST=$2
|
||||
# test mode, all or infer
|
||||
MODE=$3
|
||||
|
||||
echo "Starting inference with model $MODELNAME on datasets $DATALIST"
|
||||
# run on multi gpus with torchrun command
|
||||
# remember to run twice, the first run may fail
|
||||
torchrun --nproc_per_node=8 run.py --data $DATALIST --model $MODELNAME --mode $MODE
|
||||
torchrun --nproc_per_node=8 run.py --data $DATALIST --model $MODELNAME --mode $MODE
|
||||
# run on single gpu with python command
|
||||
# python run.py --data $DATALIST --model $MODELNAME --verbose --mode $MODE
|
||||
# python run.py --data $DATALIST --model $MODELNAME --verbose --mode $MODE
|
||||
|
||||
ls
|
||||
41
eval_mm/vlmevalkit/scripts/run_inference.sh
Normal file
@@ -0,0 +1,41 @@
|
||||
export PATH=/usr/local/cuda/bin:$PATH
|
||||
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
export OMP_NUM_THREADS=1
|
||||
export timestamp=`date +"%Y%m%d%H%M%S"`
|
||||
export OLD_VERSION='False'
|
||||
export PYTHONPATH=$(dirname $SELF_DIR):$PYTHONPATH
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
|
||||
# gpu consumed
|
||||
# fp16 17-18G
|
||||
# int4 7-8G
|
||||
|
||||
# model to be used
|
||||
# Example: MODELNAME=MiniCPM-o-2_6
|
||||
MODELNAME=$1
|
||||
# datasets to be tested
|
||||
# Example: DATALIST=MMMU_DEV_VAL
|
||||
DATALIST=$2
|
||||
|
||||
# run on multi gpus with torchrun command
|
||||
# remember to run twice, the first run may fail
|
||||
for DATASET in $DATALIST; do
|
||||
echo "Starting inference with model $MODELNAME on dataset $DATASET"
|
||||
torchrun --master_port 29500 --nproc_per_node=8 run.py --data $DATASET --model $MODELNAME --mode infer --reuse
|
||||
torchrun --master_port 29501 --nproc_per_node=8 run.py --data $DATASET --model $MODELNAME --mode infer --reuse
|
||||
|
||||
# for benchmarks which require gpt for scoring, you need to specify OPENAI_API_BASE and OPENAI_API_KEY in .env file
|
||||
if [[ "$DATASET" == *"MMBench_TEST"*]]; then
|
||||
echo "Skipping evaluation for dataset $DATASET"
|
||||
else
|
||||
echo "Starting evaluation with model $MODELNAME on datasets $DATASET"
|
||||
python run.py --data $DATASET --model $MODELNAME --nproc 16 --verbose
|
||||
fi
|
||||
done
|
||||
|
||||
# run on single gpu with python command
|
||||
# python run.py --data $DATALIST --model $MODELNAME --verbose --mode infer
|
||||
# python run.py --data $DATALIST --model $MODELNAME --verbose --mode infer
|
||||
# echo "Starting evaluation with model $MODELNAME on datasets $DATASET"
|
||||
# python run.py --data $DATASET --model $MODELNAME --nproc 16 --verbose
|
||||
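A usage sketch for the new script, assuming it is launched from the `eval_mm/vlmevalkit` directory; the dataset list below is illustrative.

```bash
# $1 = model name, $2 = space-separated dataset list (quoted so the for-loop iterates over it)
bash scripts/run_inference.sh MiniCPM-o-2_6 "MMMU_DEV_VAL MMVet"
```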
122
eval_mm/vlmevalkit/setup.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import re
|
||||
import sys
|
||||
from os.path import exists
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def parse_requirements(fname='requirements.txt', with_version=True):
|
||||
"""Parse the package dependencies listed in a requirements file but strips
|
||||
specific versioning information.
|
||||
|
||||
Args:
|
||||
fname (str): path to requirements file
|
||||
with_version (bool, default=False): if True include version specs
|
||||
|
||||
Returns:
|
||||
List[str]: list of requirements items
|
||||
|
||||
CommandLine:
|
||||
python -c "import setup; print(setup.parse_requirements())"
|
||||
"""
|
||||
|
||||
require_fpath = fname
|
||||
|
||||
def parse_line(line):
|
||||
"""Parse information from a line in a requirements text file."""
|
||||
if line.startswith('-r '):
|
||||
# Allow specifying requirements in other files
|
||||
target = line.split(' ')[1]
|
||||
for info in parse_require_file(target):
|
||||
yield info
|
||||
else:
|
||||
info = {'line': line}
|
||||
if line.startswith('-e '):
|
||||
info['package'] = line.split('#egg=')[1]
|
||||
elif '@git+' in line:
|
||||
info['package'] = line
|
||||
else:
|
||||
# Remove versioning from the package
|
||||
pat = '(' + '|'.join(['>=', '==', '>']) + ')'
|
||||
parts = re.split(pat, line, maxsplit=1)
|
||||
parts = [p.strip() for p in parts]
|
||||
|
||||
info['package'] = parts[0]
|
||||
if len(parts) > 1:
|
||||
op, rest = parts[1:]
|
||||
if ';' in rest:
|
||||
# Handle platform specific dependencies
|
||||
# http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
|
||||
version, platform_deps = map(str.strip,
|
||||
rest.split(';'))
|
||||
info['platform_deps'] = platform_deps
|
||||
else:
|
||||
version = rest # NOQA
|
||||
info['version'] = (op, version)
|
||||
yield info
|
||||
|
||||
def parse_require_file(fpath):
|
||||
with open(fpath, 'r') as f:
|
||||
for line in f.readlines():
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
for info in parse_line(line):
|
||||
yield info
|
||||
|
||||
def gen_packages_items():
|
||||
if exists(require_fpath):
|
||||
for info in parse_require_file(require_fpath):
|
||||
parts = [info['package']]
|
||||
if with_version and 'version' in info:
|
||||
parts.extend(info['version'])
|
||||
if not sys.version.startswith('3.4'):
|
||||
# apparently package_deps are broken in 3.4
|
||||
platform_deps = info.get('platform_deps')
|
||||
if platform_deps is not None:
|
||||
parts.append(';' + platform_deps)
|
||||
item = ''.join(parts)
|
||||
yield item
|
||||
|
||||
packages = list(gen_packages_items())
|
||||
return packages
|
||||
|
||||
|
||||
with open('README.md') as f:
|
||||
readme = f.read()
|
||||
|
||||
|
||||
def do_setup():
|
||||
setup(
|
||||
name='vlmeval',
|
||||
version='0.1.0',
|
||||
description='OpenCompass VLM Evaluation Kit',
|
||||
author='Haodong Duan',
|
||||
author_email='dhd.efz@gmail.com',
|
||||
maintainer='Haodong Duan',
|
||||
maintainer_email='dhd.efz@gmail.com',
|
||||
long_description=readme,
|
||||
long_description_content_type='text/markdown',
|
||||
cmdclass={},
|
||||
install_requires=parse_requirements('requirements.txt'),
|
||||
setup_requires=[],
|
||||
python_requires='>=3.7.0',
|
||||
packages=find_packages(exclude=[
|
||||
'test*',
|
||||
'paper_test*',
|
||||
]),
|
||||
keywords=['AI', 'NLP', 'in-context learning'],
|
||||
entry_points={
|
||||
'console_scripts': ['vlmutil = vlmeval:cli']
|
||||
},
|
||||
classifiers=[
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Programming Language :: Python :: 3.9',
|
||||
'Programming Language :: Python :: 3.10',
|
||||
'Intended Audience :: Developers',
|
||||
'Intended Audience :: Education',
|
||||
'Intended Audience :: Science/Research',
|
||||
])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
do_setup()
|
||||
@@ -5,9 +5,12 @@ except ImportError:
|
||||
|
||||
from .smp import *
|
||||
from .api import *
|
||||
from .evaluate import *
|
||||
from .dataset import *
|
||||
from .utils import *
|
||||
from .vlm import *
|
||||
from .config import *
|
||||
from .tools import cli
|
||||
|
||||
load_env()
|
||||
|
||||
__version__ = '0.2rc1'
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from .gpt import OpenAIWrapper, GPT4V
|
||||
from .gpt_int import OpenAIWrapperInternal, GPT4V_Internal
|
||||
|
||||
__all__ = [
|
||||
'OpenAIWrapper', 'OpenAIWrapperInternal', 'GPT4V', 'GPT4V_Internal'
|
||||
'OpenAIWrapper', 'GPT4V',
|
||||
]
|
||||
|
||||
@@ -3,7 +3,7 @@ import random as rd
|
||||
from abc import abstractmethod
|
||||
import os.path as osp
|
||||
import copy as cp
|
||||
from ..smp import get_logger, parse_file
|
||||
from ..smp import get_logger, parse_file, concat_images_vlmeval, LMUDataRoot, md5, decode_base64_to_image_file
|
||||
|
||||
|
||||
class BaseAPI:
|
||||
@@ -62,12 +62,22 @@ class BaseAPI:
|
||||
Returns:
|
||||
bool: If the API model is working, return True, else return False.
|
||||
"""
|
||||
retry = 3
|
||||
self.old_timeout = None
|
||||
if hasattr(self, 'timeout'):
|
||||
self.old_timeout = self.timeout
|
||||
self.timeout = 120
|
||||
|
||||
retry = 5
|
||||
while retry > 0:
|
||||
ret = self.generate('hello')
|
||||
if ret is not None and ret != '' and self.fail_msg not in ret:
|
||||
if self.old_timeout is not None:
|
||||
self.timeout = self.old_timeout
|
||||
return True
|
||||
retry -= 1
|
||||
|
||||
if self.old_timeout is not None:
|
||||
self.timeout = self.old_timeout
|
||||
return False
|
||||
|
||||
def check_content(self, msgs):
|
||||
@@ -127,6 +137,82 @@ class BaseAPI:
|
||||
else:
|
||||
return None
|
||||
|
||||
# May exceed the context windows size, so try with different turn numbers.
|
||||
def chat_inner(self, inputs, **kwargs):
|
||||
_ = kwargs.pop('dataset', None)
|
||||
while len(inputs):
|
||||
try:
|
||||
return self.generate_inner(inputs, **kwargs)
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
self.logger.info(f'{type(e)}: {e}')
|
||||
inputs = inputs[1:]
|
||||
while len(inputs) and inputs[0]['role'] != 'user':
|
||||
inputs = inputs[1:]
|
||||
continue
|
||||
return -1, self.fail_msg + ': ' + 'Failed with all possible conversation turns.', None
|
||||
|
||||
def chat(self, messages, **kwargs1):
|
||||
"""The main function for multi-turn chatting. Will call `chat_inner` with the preprocessed input messages."""
|
||||
assert hasattr(self, 'chat_inner'), 'The API model should have the `chat_inner` method. '
|
||||
for msg in messages:
|
||||
assert isinstance(msg, dict) and 'role' in msg and 'content' in msg, msg
|
||||
assert self.check_content(msg['content']) in ['str', 'dict', 'liststr', 'listdict'], msg
|
||||
msg['content'] = self.preproc_content(msg['content'])
|
||||
# merge kwargs
|
||||
kwargs = cp.deepcopy(self.default_kwargs)
|
||||
kwargs.update(kwargs1)
|
||||
|
||||
answer = None
|
||||
# a very small random delay [0s - 0.5s]
|
||||
T = rd.random() * 0.5
|
||||
time.sleep(T)
|
||||
|
||||
assert messages[-1]['role'] == 'user'
|
||||
|
||||
for i in range(self.retry):
|
||||
try:
|
||||
ret_code, answer, log = self.chat_inner(messages, **kwargs)
|
||||
if ret_code == 0 and self.fail_msg not in answer and answer != '':
|
||||
if self.verbose:
|
||||
print(answer)
|
||||
return answer
|
||||
elif self.verbose:
|
||||
if not isinstance(log, str):
|
||||
try:
|
||||
log = log.text
|
||||
except Exception as e:
|
||||
self.logger.warning(f'Failed to parse {log} as an http response: {str(e)}. ')
|
||||
self.logger.info(f'RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}')
|
||||
except Exception as err:
|
||||
if self.verbose:
|
||||
self.logger.error(f'An error occured during try {i}: ')
|
||||
self.logger.error(f'{type(err)}: {err}')
|
||||
# delay before each retry
|
||||
T = rd.random() * self.wait * 2
|
||||
time.sleep(T)
|
||||
|
||||
return self.fail_msg if answer in ['', None] else answer
|
||||
|
||||
def preprocess_message_with_role(self, message):
|
||||
system_prompt = ''
|
||||
new_message = []
|
||||
|
||||
for data in message:
|
||||
assert isinstance(data, dict)
|
||||
role = data.pop('role', 'user')
|
||||
if role == 'system':
|
||||
system_prompt += data['value'] + '\n'
|
||||
else:
|
||||
new_message.append(data)
|
||||
|
||||
if system_prompt != '':
|
||||
if self.system_prompt is None:
|
||||
self.system_prompt = system_prompt
|
||||
else:
|
||||
self.system_prompt += '\n' + system_prompt
|
||||
return new_message
|
||||
|
||||
def generate(self, message, **kwargs1):
|
||||
"""The main function to generate the answer. Will call `generate_inner` with the preprocessed input messages.
|
||||
|
||||
@@ -136,6 +222,9 @@ class BaseAPI:
|
||||
Returns:
|
||||
str: The generated answer of the Failed Message if failed to obtain answer.
|
||||
"""
|
||||
if self.check_content(message) == 'listdict':
|
||||
message = self.preprocess_message_with_role(message)
|
||||
|
||||
assert self.check_content(message) in ['str', 'dict', 'liststr', 'listdict'], f'Invalid input type: {message}'
|
||||
message = self.preproc_content(message)
|
||||
assert message is not None and self.check_content(message) == 'listdict'
|
||||
@@ -162,20 +251,20 @@ class BaseAPI:
|
||||
if not isinstance(log, str):
|
||||
try:
|
||||
log = log.text
|
||||
except:
|
||||
self.logger.warning(f'Failed to parse {log} as an http response. ')
|
||||
except Exception as e:
|
||||
self.logger.warning(f'Failed to parse {log} as an http response: {str(e)}. ')
|
||||
self.logger.info(f'RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}')
|
||||
except Exception as err:
|
||||
if self.verbose:
|
||||
self.logger.error(f'An error occured during try {i}:')
|
||||
self.logger.error(err)
|
||||
self.logger.error(f'An error occured during try {i}: ')
|
||||
self.logger.error(f'{type(err)}: {err}')
|
||||
# delay before each retry
|
||||
T = rd.random() * self.wait * 2
|
||||
time.sleep(T)
|
||||
|
||||
return self.fail_msg if answer in ['', None] else answer
|
||||
|
||||
def message_to_promptimg(self, message):
|
||||
def message_to_promptimg(self, message, dataset=None):
|
||||
assert not self.INTERLEAVE
|
||||
model_name = self.__class__.__name__
|
||||
import warnings
|
||||
@@ -191,5 +280,10 @@ class BaseAPI:
|
||||
image = [x['value'] for x in message if x['type'] == 'image'][0]
|
||||
else:
|
||||
prompt = '\n'.join([x['value'] if x['type'] == 'text' else '<image>' for x in message])
|
||||
image = [x['value'] for x in message if x['type'] == 'image'][0]
|
||||
if dataset == 'BLINK':
|
||||
image = concat_images_vlmeval(
|
||||
[x['value'] for x in message if x['type'] == 'image'],
|
||||
target_size=512)
|
||||
else:
|
||||
image = [x['value'] for x in message if x['type'] == 'image'][0]
|
||||
return prompt, image
|
||||
|
||||
@@ -10,18 +10,18 @@ APIBASES = {
|
||||
|
||||
def GPT_context_window(model):
|
||||
length_map = {
|
||||
'gpt-4-1106-preview': 128000,
|
||||
'gpt-4-vision-preview': 128000,
|
||||
'gpt-4': 8192,
|
||||
'gpt-4-32k': 32768,
|
||||
'gpt-4-0613': 8192,
|
||||
'gpt-4-32k-0613': 32768,
|
||||
'gpt-4-turbo-preview': 128000,
|
||||
'gpt-4-1106-preview': 128000,
|
||||
'gpt-4-0125-preview': 128000,
|
||||
'gpt-4-vision-preview': 128000,
|
||||
'gpt-4-turbo': 128000,
|
||||
'gpt-4-turbo-2024-04-09': 128000,
|
||||
'gpt-3.5-turbo': 16385,
|
||||
'gpt-3.5-turbo-0125': 16385,
|
||||
'gpt-3.5-turbo-1106': 16385,
|
||||
'gpt-3.5-turbo': 4096,
|
||||
'gpt-3.5-turbo-16k': 16385,
|
||||
'gpt-3.5-turbo-instruct': 4096,
|
||||
'gpt-3.5-turbo-0613': 4096,
|
||||
'gpt-3.5-turbo-16k-0613': 16385,
|
||||
}
|
||||
if model in length_map:
|
||||
return length_map[model]
|
||||
@@ -38,7 +38,7 @@ class OpenAIWrapper(BaseAPI):
|
||||
retry: int = 5,
|
||||
wait: int = 5,
|
||||
key: str = None,
|
||||
verbose: bool = True,
|
||||
verbose: bool = False,
|
||||
system_prompt: str = None,
|
||||
temperature: float = 0,
|
||||
timeout: int = 60,
|
||||
@@ -46,6 +46,7 @@ class OpenAIWrapper(BaseAPI):
|
||||
max_tokens: int = 1024,
|
||||
img_size: int = 512,
|
||||
img_detail: str = 'low',
|
||||
use_azure: bool = False,
|
||||
**kwargs):
|
||||
|
||||
self.model = model
|
||||
@@ -53,19 +54,43 @@ class OpenAIWrapper(BaseAPI):
|
||||
self.fail_msg = 'Failed to obtain answer via API. '
|
||||
self.max_tokens = max_tokens
|
||||
self.temperature = temperature
|
||||
self.use_azure = use_azure
|
||||
|
||||
if 'step-1v' in model:
|
||||
if 'step' in model:
|
||||
env_key = os.environ.get('STEPAI_API_KEY', '')
|
||||
if key is None:
|
||||
key = env_key
|
||||
else:
|
||||
env_key = os.environ.get('OPENAI_API_KEY', '')
|
||||
elif 'yi-vision' in model:
|
||||
env_key = os.environ.get('YI_API_KEY', '')
|
||||
if key is None:
|
||||
key = env_key
|
||||
assert isinstance(key, str) and key.startswith('sk-'), (
|
||||
f'Illegal openai_key {key}. '
|
||||
'Please set the environment variable OPENAI_API_KEY to your openai key. '
|
||||
)
|
||||
elif 'internvl2-pro' in model:
|
||||
env_key = os.environ.get('InternVL2_PRO_KEY', '')
|
||||
if key is None:
|
||||
key = env_key
|
||||
elif 'abab' in model:
|
||||
env_key = os.environ.get('MiniMax_API_KEY', '')
|
||||
if key is None:
|
||||
key = env_key
|
||||
else:
|
||||
if use_azure:
|
||||
env_key = os.environ.get('AZURE_OPENAI_API_KEY', None)
|
||||
assert env_key is not None, 'Please set the environment variable AZURE_OPENAI_API_KEY. '
|
||||
|
||||
if key is None:
|
||||
key = env_key
|
||||
assert isinstance(key, str), (
|
||||
'Please set the environment variable AZURE_OPENAI_API_KEY to your openai key. '
|
||||
)
|
||||
else:
|
||||
env_key = os.environ.get('OPENAI_API_KEY', '')
|
||||
if key is None:
|
||||
key = env_key
|
||||
assert isinstance(key, str) and key.startswith('sk-'), (
|
||||
f'Illegal openai_key {key}. '
|
||||
'Please set the environment variable OPENAI_API_KEY to your openai key. '
|
||||
)
|
||||
|
||||
self.key = key
|
||||
assert img_size > 0 or img_size == -1
|
||||
self.img_size = img_size
|
||||
@@ -75,30 +100,46 @@ class OpenAIWrapper(BaseAPI):
|
||||
|
||||
super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
|
||||
|
||||
if api_base is None:
|
||||
if 'OPENAI_API_BASE' in os.environ and os.environ['OPENAI_API_BASE'] != '':
|
||||
self.logger.error('Environment variable OPENAI_API_BASE is set. Will use it as api_base. ')
|
||||
api_base = os.environ['OPENAI_API_BASE']
|
||||
else:
|
||||
api_base = 'OFFICIAL'
|
||||
if use_azure:
|
||||
api_base_template = (
|
||||
'{endpoint}openai/deployments/{deployment_name}/chat/completions?api-version={api_version}'
|
||||
)
|
||||
endpoint = os.getenv('AZURE_OPENAI_ENDPOINT', None)
|
||||
assert endpoint is not None, 'Please set the environment variable AZURE_OPENAI_ENDPOINT. '
|
||||
deployment_name = os.getenv('AZURE_OPENAI_DEPLOYMENT_NAME', None)
|
||||
assert deployment_name is not None, 'Please set the environment variable AZURE_OPENAI_DEPLOYMENT_NAME. '
|
||||
api_version = os.getenv('OPENAI_API_VERSION', None)
|
||||
assert api_version is not None, 'Please set the environment variable OPENAI_API_VERSION. '
|
||||
|
||||
assert api_base is not None
|
||||
|
||||
if api_base in APIBASES:
|
||||
self.api_base = APIBASES[api_base]
|
||||
elif api_base.startswith('http'):
|
||||
self.api_base = api_base
|
||||
self.api_base = api_base_template.format(
|
||||
endpoint=os.getenv('AZURE_OPENAI_ENDPOINT'),
|
||||
deployment_name=os.getenv('AZURE_OPENAI_DEPLOYMENT_NAME'),
|
||||
api_version=os.getenv('OPENAI_API_VERSION')
|
||||
)
|
||||
else:
|
||||
self.logger.error('Unknown API Base. ')
|
||||
sys.exit(-1)
|
||||
if api_base is None:
|
||||
if 'OPENAI_API_BASE' in os.environ and os.environ['OPENAI_API_BASE'] != '':
|
||||
self.logger.info('Environment variable OPENAI_API_BASE is set. Will use it as api_base. ')
|
||||
api_base = os.environ['OPENAI_API_BASE']
|
||||
else:
|
||||
api_base = 'OFFICIAL'
|
||||
|
||||
assert api_base is not None
|
||||
|
||||
if api_base in APIBASES:
|
||||
self.api_base = APIBASES[api_base]
|
||||
elif api_base.startswith('http'):
|
||||
self.api_base = api_base
|
||||
else:
|
||||
self.logger.error('Unknown API Base. ')
|
||||
raise NotImplementedError
|
||||
|
||||
self.logger.info(f'Using API Base: {self.api_base}; API Key: {self.key}')
|
||||
|
||||
# inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
|
||||
# content can be a string or a list of image & text
|
||||
def prepare_inputs(self, inputs):
|
||||
input_msgs = []
|
||||
if self.system_prompt is not None:
|
||||
input_msgs.append(dict(role='system', content=self.system_prompt))
|
||||
def prepare_itlist(self, inputs):
|
||||
assert np.all([isinstance(x, dict) for x in inputs])
|
||||
has_images = np.sum([x['type'] == 'image' for x in inputs])
|
||||
if has_images:
|
||||
content_list = []
|
||||
@@ -111,11 +152,24 @@ class OpenAIWrapper(BaseAPI):
|
||||
b64 = encode_image_to_base64(img, target_size=self.img_size)
|
||||
img_struct = dict(url=f'data:image/jpeg;base64,{b64}', detail=self.img_detail)
|
||||
content_list.append(dict(type='image_url', image_url=img_struct))
|
||||
input_msgs.append(dict(role='user', content=content_list))
|
||||
else:
|
||||
assert all([x['type'] == 'text' for x in inputs])
|
||||
text = '\n'.join([x['value'] for x in inputs])
|
||||
input_msgs.append(dict(role='user', content=text))
|
||||
content_list = [dict(type='text', text=text)]
|
||||
return content_list
|
||||
|
||||
def prepare_inputs(self, inputs):
|
||||
input_msgs = []
|
||||
if self.system_prompt is not None:
|
||||
input_msgs.append(dict(role='system', content=self.system_prompt))
|
||||
assert isinstance(inputs, list) and isinstance(inputs[0], dict)
|
||||
assert np.all(['type' in x for x in inputs]) or np.all(['role' in x for x in inputs]), inputs
|
||||
if 'role' in inputs[0]:
|
||||
assert inputs[-1]['role'] == 'user', inputs[-1]
|
||||
for item in inputs:
|
||||
input_msgs.append(dict(role=item['role'], content=self.prepare_itlist(item['content'])))
|
||||
else:
|
||||
input_msgs.append(dict(role='user', content=self.prepare_itlist(inputs)))
|
||||
return input_msgs
|
||||
|
||||
def generate_inner(self, inputs, **kwargs) -> str:
|
||||
@@ -123,17 +177,24 @@ class OpenAIWrapper(BaseAPI):
|
||||
temperature = kwargs.pop('temperature', self.temperature)
|
||||
max_tokens = kwargs.pop('max_tokens', self.max_tokens)
|
||||
|
||||
context_window = GPT_context_window(self.model)
|
||||
max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
|
||||
if 0 < max_tokens <= 100:
|
||||
self.logger.warning(
|
||||
'Less than 100 tokens left, '
|
||||
'may exceed the context window with some additional meta symbols. '
|
||||
)
|
||||
if max_tokens <= 0:
|
||||
return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
|
||||
# context_window = GPT_context_window(self.model)
|
||||
# new_max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
|
||||
# if 0 < new_max_tokens <= 100 and new_max_tokens < max_tokens:
|
||||
# self.logger.warning(
|
||||
# 'Less than 100 tokens left, '
|
||||
# 'may exceed the context window with some additional meta symbols. '
|
||||
# )
|
||||
# if new_max_tokens <= 0:
|
||||
# return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
|
||||
# max_tokens = new_max_tokens
|
||||
|
||||
headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.key}'}
|
||||
# Will send request if use Azure, dk how to use openai client for it
|
||||
if self.use_azure:
|
||||
headers = {'Content-Type': 'application/json', 'api-key': self.key}
|
||||
elif 'internvl2-pro' in self.model:
|
||||
headers = {'Content-Type': 'application/json', 'Authorization': self.key}
|
||||
else:
|
||||
headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.key}'}
|
||||
payload = dict(
|
||||
model=self.model,
|
||||
messages=input_msgs,
|
||||
@@ -141,38 +202,66 @@ class OpenAIWrapper(BaseAPI):
|
||||
n=1,
|
||||
temperature=temperature,
|
||||
**kwargs)
|
||||
response = requests.post(self.api_base, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
|
||||
response = requests.post(
|
||||
self.api_base,
|
||||
headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
|
||||
ret_code = response.status_code
|
||||
ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
|
||||
answer = self.fail_msg
|
||||
try:
|
||||
resp_struct = json.loads(response.text)
|
||||
answer = resp_struct['choices'][0]['message']['content'].strip()
|
||||
except:
|
||||
pass
|
||||
except Exception as err:
|
||||
if self.verbose:
|
||||
self.logger.error(f'{type(err)}: {err}')
|
||||
self.logger.error(response.text if hasattr(response, 'text') else response)
|
||||
|
||||
return ret_code, answer, response
|
||||
|
||||
def get_image_token_len(self, img_path, detail='low'):
|
||||
import math
|
||||
if detail == 'low':
|
||||
return 85
|
||||
|
||||
im = Image.open(img_path)
|
||||
height, width = im.size
|
||||
if width > 1024 or height > 1024:
|
||||
if width > height:
|
||||
height = int(height * 1024 / width)
|
||||
width = 1024
|
||||
else:
|
||||
width = int(width * 1024 / height)
|
||||
height = 1024
|
||||
|
||||
h = math.ceil(height / 512)
|
||||
w = math.ceil(width / 512)
|
||||
total = 85 + 170 * h * w
|
||||
return total
|
||||
|
||||
def get_token_len(self, inputs) -> int:
|
||||
import tiktoken
|
||||
try:
|
||||
enc = tiktoken.encoding_for_model(self.model)
|
||||
except:
|
||||
enc = tiktoken.encoding_for_model('gpt-4')
|
||||
except Exception as err:
|
||||
if 'gpt' in self.model.lower():
|
||||
if self.verbose:
|
||||
self.logger.warning(f'{type(err)}: {err}')
|
||||
enc = tiktoken.encoding_for_model('gpt-4')
|
||||
else:
|
||||
return 0
|
||||
assert isinstance(inputs, list)
|
||||
tot = 0
|
||||
for item in inputs:
|
||||
if item['type'] == 'text':
|
||||
if 'role' in item:
|
||||
tot += self.get_token_len(item['content'])
|
||||
elif item['type'] == 'text':
|
||||
tot += len(enc.encode(item['value']))
|
||||
elif item['type'] == 'image':
|
||||
tot += 85
|
||||
if self.img_detail == 'high':
|
||||
img = Image.open(item['value'])
|
||||
npatch = np.ceil(img.size[0] / 512) * np.ceil(img.size[1] / 512)
|
||||
tot += npatch * 170
|
||||
tot += self.get_image_token_len(item['value'], detail=self.img_detail)
|
||||
return tot
|
||||
|
||||
|
||||
class GPT4V(OpenAIWrapper):
|
||||
|
||||
def generate(self, message, dataset=None):
|
||||
return super(GPT4V, self).generate(message)
|
||||
return super(GPT4V, self).generate(message)
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
import json
|
||||
import warnings
|
||||
import requests
|
||||
from ..smp import *
|
||||
from .gpt import GPT_context_window, OpenAIWrapper
|
||||
|
||||
url = 'http://ecs.sv.us.alles-apin.openxlab.org.cn/v1/openai/v2/text/chat'
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
|
||||
class OpenAIWrapperInternal(OpenAIWrapper):
|
||||
|
||||
is_api: bool = True
|
||||
|
||||
def __init__(self,
|
||||
model: str = 'gpt-3.5-turbo-0613',
|
||||
retry: int = 5,
|
||||
wait: int = 3,
|
||||
verbose: bool = True,
|
||||
system_prompt: str = None,
|
||||
temperature: float = 0,
|
||||
timeout: int = 60,
|
||||
max_tokens: int = 1024,
|
||||
img_size: int = 512,
|
||||
img_detail: str = 'low',
|
||||
**kwargs):
|
||||
|
||||
self.model = model
|
||||
if 'KEYS' in os.environ and osp.exists(os.environ['KEYS']):
|
||||
keys = load(os.environ['KEYS'])
|
||||
headers['alles-apin-token'] = keys.get('alles-apin-token', '')
|
||||
elif 'ALLES' in os.environ:
|
||||
headers['alles-apin-token'] = os.environ['ALLES']
|
||||
self.headers = headers
|
||||
self.temperature = temperature
|
||||
self.timeout = timeout
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
assert img_size > 0 or img_size == -1
|
||||
self.img_size = img_size
|
||||
assert img_detail in ['high', 'low']
|
||||
self.img_detail = img_detail
|
||||
|
||||
super(OpenAIWrapper, self).__init__(
|
||||
wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
|
||||
|
||||
def generate_inner(self, inputs, **kwargs) -> str:
|
||||
input_msgs = self.prepare_inputs(inputs)
|
||||
|
||||
temperature = kwargs.pop('temperature', self.temperature)
|
||||
max_tokens = kwargs.pop('max_tokens', self.max_tokens)
|
||||
|
||||
# Held out 100 tokens as buffer
|
||||
context_window = GPT_context_window(self.model)
|
||||
max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
|
||||
if 0 < max_tokens <= 100:
|
||||
print('Less than 100 tokens left, may exceed the context window with some additional meta symbols. ')
|
||||
if max_tokens <= 0:
|
||||
return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
|
||||
|
||||
payload = dict(
|
||||
model=self.model,
|
||||
messages=input_msgs,
|
||||
max_tokens=max_tokens,
|
||||
n=1,
|
||||
stop=None,
|
||||
timeout=self.timeout,
|
||||
temperature=temperature,
|
||||
**kwargs)
|
||||
|
||||
response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
|
||||
ret_code = response.status_code
|
||||
ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
|
||||
|
||||
answer = self.fail_msg
|
||||
try:
|
||||
resp_struct = json.loads(response.text)
|
||||
assert resp_struct['msg'] == 'ok' and resp_struct['msgCode'] == '10000', resp_struct
|
||||
answer = resp_struct['data']['choices'][0]['message']['content'].strip()
|
||||
except:
|
||||
pass
|
||||
return ret_code, answer, response
|
||||
|
||||
|
||||
class GPT4V_Internal(OpenAIWrapperInternal):
|
||||
|
||||
def generate(self, message, dataset=None):
|
||||
return super(GPT4V_Internal, self).generate(message)
|
||||
@@ -2,18 +2,19 @@ from vlmeval.vlm import *
|
||||
from vlmeval.api import *
|
||||
from functools import partial
|
||||
|
||||
ungrouped = {
|
||||
'MiniCPM-V':partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
|
||||
'MiniCPM-V-2':partial(MiniCPM_V, model_path='openbmb/MiniCPM-V-2'),
|
||||
'MiniCPM-Llama3-V-2_5':partial(MiniCPM_Llama3_V, model_path='openbmb/MiniCPM-Llama3-V-2_5'),
|
||||
minicpm_series = {
|
||||
'MiniCPM-V': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
|
||||
'MiniCPM-V-2': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V-2'),
|
||||
'MiniCPM-Llama3-V-2_5': partial(MiniCPM_Llama3_V, model_path='openbmb/MiniCPM-Llama3-V-2_5'),
|
||||
'MiniCPM-V-2_6': partial(MiniCPM_V_2_6, model_path='openbmb/MiniCPM-V-2_6'),
|
||||
'MiniCPM-o-2_6': partial(MiniCPM_o_2_6, model_path='openbmb/MiniCPM-o-2_6'),
|
||||
}
|
||||
|
||||
supported_VLM = {}
|
||||
|
||||
model_groups = [
|
||||
ungrouped
|
||||
minicpm_series
|
||||
]
|
||||
|
||||
for grp in model_groups:
|
||||
supported_VLM.update(grp)
|
||||
|
||||
|
||||
237
eval_mm/vlmevalkit/vlmeval/dataset/__init__.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import warnings
|
||||
|
||||
from .image_base import img_root_map, ImageBaseDataset
|
||||
from .image_caption import ImageCaptionDataset
|
||||
from .image_yorn import ImageYORNDataset
|
||||
from .image_mcq import (
|
||||
ImageMCQDataset, MMMUDataset, CustomMCQDataset, MUIRDataset, GMAIMMBenchDataset, MMERealWorld, HRBenchDataset,
|
||||
NaturalBenchDataset
|
||||
)
|
||||
from .image_mt import MMDUDataset
|
||||
from .image_vqa import (
|
||||
ImageVQADataset, MathVision, OCRBench, MathVista, LLaVABench, MMVet, MTVQADataset, TableVQABench,
|
||||
CustomVQADataset, CRPE, MathVerse, OlympiadBench, QSpatial, VizWiz, MMNIAH, WeMath, LogicVista
|
||||
)
|
||||
|
||||
from .image_ccocr import CCOCRDataset
|
||||
from .text_mcq import CustomTextMCQDataset, TextMCQDataset
|
||||
|
||||
from .vcr import VCRDataset
|
||||
from .mmlongbench import MMLongBench
|
||||
from .dude import DUDE
|
||||
from .slidevqa import SlideVQA
|
||||
from .vl_rewardbench import VLRewardBench
|
||||
|
||||
from .mmbench_video import MMBenchVideo
|
||||
from .videomme import VideoMME
|
||||
from .mvbench import MVBench, MVBench_MP4
|
||||
from .mlvu import MLVU, MLVU_MCQ, MLVU_OpenEnded
|
||||
from .tempcompass import TempCompass, TempCompass_Captioning, TempCompass_MCQ, TempCompass_YorN
|
||||
from .longvideobench import LongVideoBench
|
||||
from .video_concat_dataset import ConcatVideoDataset
|
||||
from .mmgenbench import MMGenBench
|
||||
from .cgbench import CGBench_MCQ_Grounding_Mini, CGBench_OpenEnded_Mini, CGBench_MCQ_Grounding, CGBench_OpenEnded
|
||||
|
||||
from .miabench import MIABench
|
||||
from .cmmmu import CMMMU
|
||||
from .wildvision import WildVision
|
||||
from .mmmath import MMMath
|
||||
from .dynamath import Dynamath
|
||||
from .utils import *
|
||||
from .video_dataset_config import *
|
||||
from ..smp import *
|
||||
|
||||
|
||||
class ConcatDataset(ImageBaseDataset):
|
||||
# This dataset takes multiple dataset names as input and aggregate them into a single dataset.
|
||||
# Each single dataset should not have a field named `SUB_DATASET`
|
||||
|
||||
DATASET_SETS = {
|
||||
'MMMB': ['MMMB_ar', 'MMMB_cn', 'MMMB_en', 'MMMB_pt', 'MMMB_ru', 'MMMB_tr'],
|
||||
'MTL_MMBench_DEV': [
|
||||
'MMBench_dev_ar', 'MMBench_dev_cn', 'MMBench_dev_en',
|
||||
'MMBench_dev_pt', 'MMBench_dev_ru', 'MMBench_dev_tr'
|
||||
]
|
||||
}
|
||||
|
||||
def __init__(self, dataset):
|
||||
datasets = self.DATASET_SETS[dataset]
|
||||
self.dataset_map = {}
|
||||
# The name of the compilation
|
||||
self.dataset_name = dataset
|
||||
self.datasets = datasets
|
||||
for dname in datasets:
|
||||
dataset = build_dataset(dname)
|
||||
assert dataset is not None, dataset
|
||||
self.dataset_map[dname] = dataset
|
||||
TYPES = [x.TYPE for x in self.dataset_map.values()]
|
||||
MODALITIES = [x.MODALITY for x in self.dataset_map.values()]
|
||||
assert np.all([x == TYPES[0] for x in TYPES]), (datasets, TYPES)
|
||||
assert np.all([x == MODALITIES[0] for x in MODALITIES]), (datasets, MODALITIES)
|
||||
self.TYPE = TYPES[0]
|
||||
self.MODALITY = MODALITIES[0]
|
||||
data_all = []
|
||||
for dname in datasets:
|
||||
data = self.dataset_map[dname].data
|
||||
data['SUB_DATASET'] = [dname] * len(data)
|
||||
data_new = localize_df(data, dname, nproc=16)
|
||||
data_all.append(data_new)
|
||||
|
||||
data = pd.concat(data_all)
|
||||
data['original_index'] = data.pop('index')
|
||||
data['index'] = np.arange(len(data))
|
||||
self.data = data
|
||||
|
||||
def build_prompt(self, line):
|
||||
if isinstance(line, int):
|
||||
line = self.data.iloc[line]
|
||||
idx = line['original_index']
|
||||
dname = line['SUB_DATASET']
|
||||
org_data = self.dataset_map[dname].data
|
||||
org_line = cp.deepcopy(org_data[org_data['index'] == idx]).iloc[0]
|
||||
return self.dataset_map[dname].build_prompt(org_line)
|
||||
|
||||
def dump_image(self, line):
|
||||
# Assert all images are pre-dumped
|
||||
assert 'image' not in line
|
||||
assert 'image_path' in line
|
||||
tgt_path = toliststr(line['image_path'])
|
||||
return tgt_path
|
||||
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return list(cls.DATASET_SETS)
|
||||
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
suffix = eval_file.split('.')[-1]
|
||||
# First, split the eval_file by dataset
|
||||
data_all = load(eval_file)
|
||||
for dname in self.datasets:
|
||||
tgt = eval_file.replace(self.dataset_name, dname)
|
||||
data_sub = data_all[data_all['SUB_DATASET'] == dname]
|
||||
data_sub.pop('index')
|
||||
data_sub['index'] = data_sub.pop('original_index')
|
||||
data_sub.pop('SUB_DATASET')
|
||||
dump(data_sub, tgt)
|
||||
# Then, evaluate each dataset separately
|
||||
results_all = []
|
||||
for dname in self.datasets:
|
||||
tgt = eval_file.replace(self.dataset_name, dname)
|
||||
res = self.dataset_map[dname].evaluate(tgt, **judge_kwargs)
|
||||
assert isinstance(res, pd.DataFrame)
|
||||
res['DATASET'] = [dname] * len(res)
|
||||
results_all.append(res)
|
||||
result = pd.concat(results_all)
|
||||
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
||||
dump(result, score_file)
|
||||
return result
|
||||
|
||||
|
||||
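# Illustrative sketch (not part of the original file): the aggregation above boils
# down to tagging each sub-dataframe with its source and re-indexing the result.
# Hypothetical data; pandas/numpy come in via the star imports, as elsewhere here.
def _concat_sketch():
    parts = {'MMMB_en': pd.DataFrame({'index': [0, 1], 'question': ['q1', 'q2']}),
             'MMMB_cn': pd.DataFrame({'index': [0, 1], 'question': ['问1', '问2']})}
    frames = []
    for name, df in parts.items():
        df = df.copy()
        df['SUB_DATASET'] = name                      # remember the source dataset
        frames.append(df)
    merged = pd.concat(frames)
    merged['original_index'] = merged.pop('index')    # keep the per-dataset index
    merged['index'] = np.arange(len(merged))          # globally unique index
    return merged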
# Add new supported dataset class here
|
||||
IMAGE_DATASET = [
|
||||
ImageCaptionDataset, ImageYORNDataset, ImageMCQDataset, ImageVQADataset, MathVision,
|
||||
MMMUDataset, OCRBench, MathVista, LLaVABench, MMVet, MTVQADataset, TableVQABench,
|
||||
MMLongBench, VCRDataset, MMDUDataset, DUDE, SlideVQA, MUIRDataset, CCOCRDataset,
|
||||
GMAIMMBenchDataset, MMERealWorld, HRBenchDataset, CRPE, MathVerse, NaturalBenchDataset,
|
||||
MIABench, OlympiadBench, WildVision, MMMath, QSpatial, Dynamath, MMGenBench, VizWiz, MMNIAH,
|
||||
CMMMU, VLRewardBench, WeMath, LogicVista
|
||||
]
|
||||
|
||||
VIDEO_DATASET = [
|
||||
MMBenchVideo, VideoMME, MVBench, MVBench_MP4, LongVideoBench,
|
||||
MLVU, MLVU_MCQ, MLVU_OpenEnded,
|
||||
TempCompass, TempCompass_MCQ, TempCompass_Captioning, TempCompass_YorN,
|
||||
CGBench_MCQ_Grounding_Mini, CGBench_OpenEnded_Mini, CGBench_MCQ_Grounding, CGBench_OpenEnded
|
||||
]
|
||||
|
||||
TEXT_DATASET = [
|
||||
TextMCQDataset
|
||||
]
|
||||
|
||||
CUSTOM_DATASET = [
|
||||
CustomMCQDataset, CustomVQADataset, CustomTextMCQDataset
|
||||
]
|
||||
|
||||
DATASET_COLLECTION = [ConcatDataset, ConcatVideoDataset]
|
||||
|
||||
DATASET_CLASSES = IMAGE_DATASET + VIDEO_DATASET + TEXT_DATASET + CUSTOM_DATASET + DATASET_COLLECTION
|
||||
SUPPORTED_DATASETS = []
|
||||
for DATASET_CLS in DATASET_CLASSES:
|
||||
SUPPORTED_DATASETS.extend(DATASET_CLS.supported_datasets())
|
||||
|
||||
|
||||
def DATASET_TYPE(dataset, *, default: str = 'MCQ') -> str:
|
||||
for cls in DATASET_CLASSES:
|
||||
if dataset in cls.supported_datasets():
|
||||
if hasattr(cls, 'TYPE'):
|
||||
return cls.TYPE
|
||||
# Have to add specific routine to handle ConcatDataset
|
||||
if dataset in ConcatDataset.DATASET_SETS:
|
||||
dataset_list = ConcatDataset.DATASET_SETS[dataset]
|
||||
TYPES = [DATASET_TYPE(dname) for dname in dataset_list]
|
||||
assert np.all([x == TYPES[0] for x in TYPES]), (dataset_list, TYPES)
|
||||
return TYPES[0]
|
||||
|
||||
if 'openended' in dataset.lower():
|
||||
return 'VQA'
|
||||
warnings.warn(f'Dataset {dataset} is a custom one and not annotated as `openended`, will treat as {default}. ')
|
||||
return default
|
||||
|
||||
|
||||
def DATASET_MODALITY(dataset, *, default: str = 'IMAGE') -> str:
|
||||
if dataset is None:
|
||||
warnings.warn(f'Dataset is not specified, will treat modality as {default}. ')
|
||||
return default
|
||||
for cls in DATASET_CLASSES:
|
||||
if dataset in cls.supported_datasets():
|
||||
if hasattr(cls, 'MODALITY'):
|
||||
return cls.MODALITY
|
||||
# Have to add specific routine to handle ConcatDataset
|
||||
if dataset in ConcatDataset.DATASET_SETS:
|
||||
dataset_list = ConcatDataset.DATASET_SETS[dataset]
|
||||
MODALITIES = [DATASET_MODALITY(dname) for dname in dataset_list]
|
||||
assert np.all([x == MODALITIES[0] for x in MODALITIES]), (dataset_list, MODALITIES)
|
||||
return MODALITIES[0]
|
||||
|
||||
if 'video' in dataset.lower():
|
||||
return 'VIDEO'
|
||||
elif 'image' in dataset.lower():
|
||||
return 'IMAGE'
|
||||
warnings.warn(f'Dataset {dataset} is a custom one, will treat modality as {default}. ')
|
||||
return default
|
||||
|
||||
|
||||
def build_dataset(dataset_name, **kwargs):
|
||||
for cls in DATASET_CLASSES:
|
||||
if dataset_name in supported_video_datasets:
|
||||
return supported_video_datasets[dataset_name](**kwargs)
|
||||
elif dataset_name in cls.supported_datasets():
|
||||
return cls(dataset=dataset_name, **kwargs)
|
||||
|
||||
warnings.warn(f'Dataset {dataset_name} is not officially supported. ')
|
||||
|
||||
data_file = osp.join(LMUDataRoot(), f'{dataset_name}.tsv')
|
||||
if not osp.exists(data_file):
|
||||
warnings.warn(f'Data file {data_file} does not exist. Dataset building failed. ')
|
||||
return None
|
||||
|
||||
data = load(data_file)
|
||||
if 'question' not in [x.lower() for x in data.columns]:
|
||||
warnings.warn(f'Data file {data_file} does not have a `question` column. Dataset building failed. ')
|
||||
return None
|
||||
|
||||
if 'A' in data and 'B' in data:
|
||||
if 'image' in data or 'image_path' in data:
|
||||
warnings.warn(f'Will assume unsupported dataset {dataset_name} as a Custom MCQ dataset. ')
|
||||
return CustomMCQDataset(dataset=dataset_name, **kwargs)
|
||||
else:
|
||||
warnings.warn(f'Will assume unsupported dataset {dataset_name} as a Custom Text MCQ dataset. ')
|
||||
return CustomTextMCQDataset(dataset=dataset_name, **kwargs)
|
||||
else:
|
||||
warnings.warn(f'Will assume unsupported dataset {dataset_name} as a Custom VQA dataset. ')
|
||||
return CustomVQADataset(dataset=dataset_name, **kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'build_dataset', 'img_root_map', 'build_judge', 'extract_answer_from_item', 'prefetch_answer', 'DEBUG_MESSAGE'
|
||||
] + [cls.__name__ for cls in DATASET_CLASSES]
|
||||
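# Illustrative sketch (not part of the original file): the fallback in build_dataset
# means a private benchmark only needs a TSV under LMUDataRoot(). A hypothetical file
# that the heuristics above would pick up as a Custom MCQ dataset:
def _custom_mcq_tsv_sketch(root):
    rows = [dict(index=0, question='What color is the sky?', A='blue', B='green',
                 answer='A', image_path='sky.jpg')]
    pd.DataFrame(rows).to_csv(osp.join(root, 'MyPrivateMCQ.tsv'), sep='\t', index=False)
    # build_dataset('MyPrivateMCQ') would then warn and return a CustomMCQDataset.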
1760  eval_mm/vlmevalkit/vlmeval/dataset/cgbench.py  (new file, contents not shown here)
354   eval_mm/vlmevalkit/vlmeval/dataset/cmmmu.py  (new file)
@@ -0,0 +1,354 @@
|
||||
from .image_base import ImageBaseDataset
|
||||
import random
|
||||
from collections import Counter
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from ..smp import *
|
||||
|
||||
|
||||
def get_multi_choice_prediction(response, all_choices, index2ans):
|
||||
for char in [',', '.', '!', '?', ';', ':', "'"]:
|
||||
response = response.strip(char)
|
||||
response = " " + response + " " # add space to avoid partial match
|
||||
|
||||
candidates = []
|
||||
|
||||
for choice in all_choices: # (A) (B) (C) (D)
|
||||
# Add the choice to candidates each time it appears in the response
|
||||
candidates.extend([choice for _ in range(response.count(f'({choice})'))])
|
||||
|
||||
if len(candidates) == 0:
|
||||
for choice in all_choices: # A B C D
|
||||
# Similarly, add the choice for each occurrence
|
||||
candidates.extend([choice for _ in range(response.count(f'{choice}'))])
|
||||
|
||||
if len(candidates) == 0 and len(response.split()) >= 1:
|
||||
for index, ans in index2ans.items():
|
||||
# Add index for each occurrence of ans in response
|
||||
candidates.extend([index for _ in range(response.count(ans))])
|
||||
|
||||
# If none of the above yields candidates, fall back to matching the option text anywhere in the response
|
||||
if len(candidates) == 0 and len(response.split()) >= 1:
|
||||
for index, ans in index2ans.items():
|
||||
if ans in response:
|
||||
candidates.append(index)
|
||||
# index_ans = False # it's content ans.
|
||||
|
||||
if len(candidates) == 0: # still not get answer, randomly choose one.
|
||||
return random.choice(all_choices)
|
||||
# return ''
|
||||
else:
|
||||
# Count the occurrence of each candidate
|
||||
candidate_counts = Counter(candidates)
|
||||
|
||||
# Select the most frequent candidates
|
||||
max_count = max(candidate_counts.values())
|
||||
most_frequent_candidates = [c for c in all_choices if candidate_counts.get(c, 0) == max_count]
|
||||
|
||||
# Combine the most frequent candidates in ABCD order
|
||||
return ''.join(most_frequent_candidates)
|
||||
|
||||
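# Illustrative sketch (not part of the original file): the routine above votes for
# bracketed option letters first, then bare letters, then option text, breaking
# ties in A-D order.  A tiny standalone re-implementation of the first stage:
def _bracketed_vote_sketch(response, choices=('A', 'B', 'C', 'D')):
    votes = {ch: response.count(f'({ch})') for ch in choices}
    if max(votes.values()) == 0:
        return None                     # the real code falls back to later stages
    best = max(votes.values())
    return ''.join(ch for ch in choices if votes[ch] == best)
# _bracketed_vote_sketch("It is (B), not (C); final answer (B).") -> 'B'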
|
||||
def extract_numbers(string):
|
||||
# Pattern for numbers with thousands-separator commas
|
||||
pattern_commas = r'-?\d{1,3}(?:,\d{3})+'
|
||||
# Pattern for scientific notation
|
||||
pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
|
||||
# Pattern for simple numbers without commas
|
||||
pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+)(?![eE][+-]?\d+)(?!,\d)'
|
||||
|
||||
# Extract numbers with thousands-separator commas
|
||||
numbers_with_commas = re.findall(pattern_commas, string)
|
||||
# Extract numbers in scientific notation
|
||||
numbers_scientific = re.findall(pattern_scientific, string)
|
||||
# Extract simple numbers without commas
|
||||
numbers_simple = re.findall(pattern_simple, string)
|
||||
|
||||
# Combine all extracted numbers
|
||||
all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
|
||||
return all_numbers
|
||||
|
||||
|
||||
def check_is_number(string):
|
||||
try:
|
||||
float(string.replace(',', ''))
|
||||
return True
|
||||
except ValueError:
|
||||
# check if there's comma inside
|
||||
return False
|
||||
|
||||
|
||||
def count_letters(string):
|
||||
return sum(c.isalpha() and 'a' <= c <= 'z' or 'A' <= c <= 'Z' for c in string)
|
||||
|
||||
|
||||
def normalize_str(string, answer):
|
||||
# check if characters in the string
|
||||
|
||||
# if number, numerize it.
|
||||
if string is None:
|
||||
return [string]
|
||||
string = string.strip()
|
||||
|
||||
is_number = check_is_number(string)
|
||||
|
||||
if is_number:
|
||||
string = string.replace(',', '')
|
||||
string = float(string)
|
||||
# keep 2 decimal places
|
||||
string = round(string, 2)
|
||||
return [string]
|
||||
else: # it's likely to be a string
|
||||
if len(string) > len(answer) + 20 or count_letters(string) > count_letters(answer) + 2:
|
||||
return []
|
||||
return [string]
|
||||
|
||||
|
||||
def get_fill_blank_prediction(response, answer):
|
||||
"""get the prediction from the generated response,
|
||||
return a list of predicted strings or numbers"""
|
||||
|
||||
def get_key_subresponses(response):
|
||||
response = response.strip("。").strip()
|
||||
sub_responses = re.split(r'。|\n', response)
|
||||
indicators_of_keys = ['是', '为', '所以', '等于', '方案', '选择',
|
||||
'正确答案', '因此', '最后', '答案', '结果']
|
||||
key_responses = []
|
||||
for index, resp in enumerate(sub_responses):
|
||||
# if last one, accept it's an equation (the entire response can be just one sentence with equation)
|
||||
if index == len(sub_responses) - 1:
|
||||
indicators_of_keys.extend(['='])
|
||||
shortest_key_response = None
|
||||
# the shortest response that may contain the answer (tail part of the response)
|
||||
for indicator in indicators_of_keys:
|
||||
if indicator in resp:
|
||||
if not shortest_key_response:
|
||||
shortest_key_response = resp.split(indicator)[-1].strip()
|
||||
else:
|
||||
if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
|
||||
shortest_key_response = resp.split(indicator)[-1].strip()
|
||||
|
||||
if shortest_key_response:
|
||||
# and it's not trivial
|
||||
if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
|
||||
key_responses.append(shortest_key_response)
|
||||
if len(key_responses) == 0: # did not find any
|
||||
return [response]
|
||||
return key_responses
|
||||
|
||||
key_responses = get_key_subresponses(response)
|
||||
|
||||
pred_list = key_responses.copy() # keep the original string response
|
||||
for resp in key_responses:
|
||||
pred_list.extend(extract_numbers(resp))
|
||||
|
||||
tmp_pred_list = []
|
||||
for i in range(len(pred_list)):
|
||||
tmp_pred_list.extend(normalize_str(pred_list[i], answer))
|
||||
pred_list = tmp_pred_list
|
||||
|
||||
# remove duplicates
|
||||
pred_list = list(set(pred_list))
|
||||
|
||||
return pred_list
|
||||
|
||||
|
||||
def get_TF_prediction(response):
|
||||
"""get the prediction from the generated response,
|
||||
return a list of predicted strings or numbers"""
|
||||
|
||||
def get_key_subresponses(response):
|
||||
response = response.strip("。").strip()
|
||||
sub_responses = re.split(r'。|\n', response)
|
||||
indicators_of_keys = ['是', '为', '所以', '判断',
|
||||
'陈述', '说法', '表达', '答案', '结果']
|
||||
key_responses = []
|
||||
for index, resp in enumerate(sub_responses):
|
||||
shortest_key_response = None
|
||||
# the shortest response that may contain the answer (tail part of the response)
|
||||
for indicator in indicators_of_keys:
|
||||
if indicator in resp:
|
||||
if not shortest_key_response:
|
||||
shortest_key_response = resp.split(indicator)[-1].strip()
|
||||
else:
|
||||
if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
|
||||
shortest_key_response = resp.split(indicator)[-1].strip()
|
||||
|
||||
if shortest_key_response:
|
||||
# and it's not trivial
|
||||
if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
|
||||
key_responses.append(shortest_key_response)
|
||||
if len(key_responses) == 0: # did not find any
|
||||
return [response]
|
||||
return key_responses
|
||||
|
||||
key_responses = get_key_subresponses(response)
|
||||
|
||||
pred_list = key_responses.copy() # keep the original string response
|
||||
# remove duplicates
|
||||
pred_list = list(set(pred_list))
|
||||
|
||||
return pred_list
|
||||
|
||||
|
||||
class CMMMU(ImageBaseDataset):
|
||||
TYPE = 'VQA'
|
||||
|
||||
DATASET_URL = {
|
||||
'CMMMU_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/CMMMU_VAL.tsv'
|
||||
}
|
||||
|
||||
DATASET_MD5 = {
|
||||
'CMMMU_VAL': 'b4727e2fce2415bf646379e60c11a726'
|
||||
}
|
||||
|
||||
def dump_image(self, line):
|
||||
os.makedirs(self.img_root, exist_ok=True)
|
||||
|
||||
tgt_path_z = []
|
||||
if isinstance(line['image'], list):
|
||||
for i in range(len(line['image'])):
|
||||
tgt_path = osp.join(self.img_root, f"{line['index']}--{i + 1}.jpg")
|
||||
if not read_ok(tgt_path):
|
||||
decode_base64_to_image_file(line['image'][i], tgt_path)
|
||||
tgt_path_z.append(tgt_path)
|
||||
else:
|
||||
tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
|
||||
if not read_ok(tgt_path):
|
||||
decode_base64_to_image_file(line['image'], tgt_path)
|
||||
tgt_path_z.append(tgt_path)
|
||||
return tgt_path_z
|
||||
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
|
||||
suffix = eval_file.split('.')[-1]
|
||||
result_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
||||
|
||||
if not osp.exists(result_file):
|
||||
data = load(eval_file)
|
||||
assert 'answer' in data and 'prediction' in data
|
||||
data['prediction'] = [str(x) for x in data['prediction']]
|
||||
data['answer'] = [str(x) for x in data['answer']]
|
||||
|
||||
correct_count = 0
|
||||
correct_category = {
|
||||
'技术与工程': [0, 0],
|
||||
'科学': [0, 0],
|
||||
'健康与医学': [0, 0],
|
||||
'商业': [0, 0],
|
||||
'艺术与设计': [0, 0],
|
||||
'人文社会科学': [0, 0],
|
||||
}
|
||||
|
||||
for i in tqdm(data.iterrows()):
|
||||
line = i[1]
|
||||
correct_category[line['category']][0] += 1
|
||||
|
||||
# Options
|
||||
if line['type'] == '选择':
|
||||
index2ans = {
|
||||
'A': line['option1'],
|
||||
'B': line['option2'],
|
||||
'C': line['option3'],
|
||||
'D': line['option4']
|
||||
}
|
||||
fact_option = get_multi_choice_prediction(line['prediction'], ['A', 'B', 'C', 'D'], index2ans)
|
||||
if fact_option == line['answer']:
|
||||
correct_count += 1
|
||||
correct_category[line['category']][1] += 1
|
||||
|
||||
# Binary
|
||||
elif line['type'] == '判断':
|
||||
positive_keywords = ['正确', '对', '准确', '肯定', '对的']
|
||||
negative_keywords = ['不对', '错误', '不正确', '不准确', '不合适', '否定', '错的', '错']
|
||||
ambiguous_keywords = ['对错', '是否正确', '否正确', '或者', '是否', '正确性', '对不']
|
||||
|
||||
def judge_similarity(pred_list, positive_keywords, negative_keywords):
|
||||
positive_count = 0
|
||||
negative_count = 0
|
||||
|
||||
for pred in pred_list:
|
||||
if any(pos_word in pred for pos_word in positive_keywords):
|
||||
positive_count += 1
|
||||
elif any(neg_word in pred for neg_word in negative_keywords):
|
||||
negative_count += 1
|
||||
|
||||
if positive_count > negative_count:
|
||||
return "对"
|
||||
elif negative_count > positive_count:
|
||||
return "错"
|
||||
else:
|
||||
return random.choice(['对', '错'])
|
||||
|
||||
answer = get_TF_prediction(line['prediction'])
|
||||
answer = [word for word in answer if not any(ambiguous in word for ambiguous in ambiguous_keywords)]
|
||||
fact_answer = judge_similarity(answer, positive_keywords, negative_keywords)
|
||||
if fact_answer == line['answer']:
|
||||
correct_count += 1
|
||||
correct_category[line['category']][1] += 1
|
||||
|
||||
# Fill the Blank
|
||||
else:
|
||||
norm_answers = normalize_str(line['answer'], line['answer'])
|
||||
predicted_answer = get_fill_blank_prediction(line['prediction'], line['answer'])
|
||||
|
||||
for pred in predicted_answer:
|
||||
# already normalized
|
||||
if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
|
||||
for norm_ans in norm_answers:
|
||||
# only see if the string answer in the string pred
|
||||
# print(norm_ans, pred)
|
||||
if isinstance(norm_ans, str) and norm_ans in pred:
|
||||
correct_count += 1
|
||||
correct_category[line['category']][1] += 1
|
||||
else: # it's a number
|
||||
if pred in norm_answers:
|
||||
correct_count += 1
|
||||
correct_category[line['category']][1] += 1
|
||||
|
||||
accuracyz = {}
|
||||
accuracyz['总准确率'] = correct_count / len(data)
|
||||
for i in correct_category.keys():
|
||||
accuracyz[i] = correct_category[i][1] / correct_category[i][0]
|
||||
|
||||
accuracyz = d2df(accuracyz)
|
||||
accuracyz = accuracyz.round(10)
|
||||
dump(accuracyz, result_file)
|
||||
|
||||
result = pd.read_csv(result_file)
|
||||
return result
|
||||
|
||||
def build_prompt(self, line):
|
||||
if line['type'] == '选择':
|
||||
tgt_path = self.dump_image(line)
|
||||
question = line['question']
|
||||
options_prompt = 'Options:\n'
|
||||
|
||||
for i in [['A', '1'], ['B', '2'], ['C', '3'], ['D', '4']]:
|
||||
options_prompt += i[0] + '. ' + line['option' + i[1]] + '\n'
|
||||
|
||||
prompt = (f'问题: {question}\n' + options_prompt
|
||||
+ '请回答上述多项选择题,并选出正确选项。这些题目可能包括单选和多选题型。如果所提供的信息不足以确定一个明确的答案,那么请根据可用的数据和你的判断来选择最可能正确的选项。')
|
||||
|
||||
msgs = []
|
||||
if isinstance(tgt_path, list):
|
||||
msgs.extend([dict(type='image', value=p) for p in tgt_path])
|
||||
else:
|
||||
msgs = [dict(type='image', value=tgt_path)]
|
||||
msgs.append(dict(type='text', value=prompt))
|
||||
|
||||
return msgs
|
||||
|
||||
elif line['type'] == '判断':
|
||||
msgs = super().build_prompt(line)
|
||||
assert msgs[-1]['type'] == 'text'
|
||||
msgs[-1]['value'] += '\n请回答上述判断题,并根据题目描述和所给的信息来判断问题中陈述的对错。如果信息不完整或不足以作出绝对判断,请运用你的逻辑推理和现有信息来做出最可能的判断。'
|
||||
return msgs
|
||||
|
||||
else:
|
||||
msgs = super().build_prompt(line)
|
||||
assert msgs[-1]['type'] == 'text'
|
||||
msgs[-1]['value'] += '\n请回答上述填空题,并根据题目的要求和所提供的信息来给出最恰当的答案。如果信息不足以确切回答,那么请依据现有的数据和你的推理能力来填写最合理的答案。'
|
||||
return msgs
|
||||
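# Illustrative sketch (not part of the original file): the per-category bookkeeping
# in evaluate() is a {category: [total, correct]} dict; turning it into the dumped
# CSV is essentially the following (hypothetical counts):
def _cmmmu_score_sketch():
    correct_category = {'科学': [40, 31], '商业': [25, 18]}
    acc = {'总准确率': (31 + 18) / (40 + 25)}
    acc.update({k: c / t for k, (t, c) in correct_category.items()})
    return d2df(acc)   # one-row DataFrame, analogous to dump(d2df(accuracyz), result_file)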
211  eval_mm/vlmevalkit/vlmeval/dataset/dude.py  (new file)
@@ -0,0 +1,211 @@
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
from .utils.judge_util import build_judge
|
||||
from .image_base import ImageBaseDataset
|
||||
from .mmlongbench import concat_images, MMLongBench_auxeval, anls_compute
|
||||
from ..smp import *
|
||||
|
||||
|
||||
FAIL_MSG = 'Failed to obtain answer via API.'
|
||||
|
||||
|
||||
def DUDE_acc(result_file):
|
||||
data = load(result_file)
|
||||
overall_score = 0.0
|
||||
score_list = list()
|
||||
for i in range(len(data)):
|
||||
item = data.iloc[i]
|
||||
if isinstance(item['answer'], float) and math.isnan(item['answer']):
|
||||
item['answer'] = 'Not answerable'
|
||||
|
||||
item['answer'] = item['answer'].lower()
|
||||
item['pred'] = item['pred'].lower()
|
||||
score = anls_compute(item['answer'], item['pred'])
|
||||
score_list.append(score)
|
||||
overall_score += score
|
||||
|
||||
data['score'] = score_list
|
||||
dump(data, result_file)
|
||||
|
||||
res = dict()
|
||||
res['category'], res['num'], res['avg_score'] = ['anls'], [len(data)], [overall_score / len(data)]
|
||||
res = pd.DataFrame(res)
|
||||
return res
|
||||
|
||||
|
||||
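# Illustrative sketch (not part of the original file): ANLS, as commonly defined for
# DocVQA-style benchmarks, is normalized Levenshtein similarity zeroed below a
# threshold.  A minimal reference version (whether the imported anls_compute matches
# this exactly is not shown here):
def _anls_reference_sketch(gt, pred, tau=0.5):
    m, n = len(gt), len(pred)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0 or j == 0:
                dp[i][j] = i + j
            else:
                dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1,
                               dp[i - 1][j - 1] + (gt[i - 1] != pred[j - 1]))
    nl = dp[m][n] / max(m, n, 1)        # normalized edit distance
    return 1 - nl if nl < tau else 0.0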
class DUDE(ImageBaseDataset):
|
||||
|
||||
TYPE = 'VQA'
|
||||
|
||||
DATASET_URL = {
|
||||
'DUDE': 'https://opencompass.openxlab.space/utils/VLMEval/DUDE.tsv',
|
||||
'DUDE_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/DUDE_MINI.tsv',
|
||||
}
|
||||
DATASET_MD5 = {
|
||||
'DUDE': '130d860d08206e1e407cd77150c10d88',
|
||||
'DUDE_MINI': 'e0c0d998114f0cca7516d12039d2b538',
|
||||
}
|
||||
|
||||
SUPPORTED_MODELS = {
|
||||
'GPT4': (1, 1),
|
||||
'GPT4V': (1, 1),
|
||||
'GPT4V_HIGH': (1, 1),
|
||||
'GPT4o': (1, 1),
|
||||
'GPT4o_HIGH': (1, 1),
|
||||
'GPT4o_MINI': (1, 1),
|
||||
'XComposer2d5': (1, -1),
|
||||
'XComposer2_4KHD': (1, -1),
|
||||
'MiniCPM-Llama3-V-2_5': (1, 5),
|
||||
'InternVL-Chat-V1-5': (5, 2),
|
||||
}
|
||||
|
||||
def __init__(self, dataset, **kwargs):
|
||||
self.model_list = list(self.SUPPORTED_MODELS.keys())
|
||||
model_name = kwargs['model']
|
||||
if not listinstr(self.model_list, model_name):
|
||||
raise AssertionError("{} doesn't support the evaluation on DUDE.".format(model_name))
|
||||
super(DUDE, self).__init__(dataset)
|
||||
|
||||
self.is_api = True if listinstr(['GPT4'], model_name) else False
|
||||
self.max_pages = 120
|
||||
concat_num, column_num = self.SUPPORTED_MODELS.get(model_name)
|
||||
self.concat_num = concat_num
|
||||
self.column_num = column_num
|
||||
|
||||
def prepare_tsv(self, url, file_md5=None):
|
||||
data_root = LMUDataRoot()
|
||||
os.makedirs(data_root, exist_ok=True)
|
||||
file_name = url.split('/')[-1]
|
||||
data_path = osp.join(data_root, file_name)
|
||||
if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
|
||||
pass
|
||||
else:
|
||||
warnings.warn('The dataset tsv is not downloaded')
|
||||
download_file(url, data_path)
|
||||
return load(data_path)
|
||||
|
||||
def dump_image(self, origin_line):
|
||||
os.makedirs(self.img_root, exist_ok=True)
|
||||
try:
|
||||
import fitz
|
||||
except Exception as e:
|
||||
logging.critical(f'{type(e)}: {e}')
|
||||
logging.critical('Please use `pip install pymupdf` to parse PDF files.')
|
||||
|
||||
line = origin_line.copy()
|
||||
if not isinstance(line['image_path'], List):
|
||||
line['image_path'] = [line['image_path']]
|
||||
line['image_path'] = line['image_path'][:self.max_pages]
|
||||
skip_pdf_parse = True
|
||||
for im_name in line['image_path']:
|
||||
path = osp.join(self.img_root, im_name)
|
||||
if not read_ok(path):
|
||||
skip_pdf_parse = False
|
||||
break
|
||||
|
||||
# Just to stay compatible with the zipped loop: zip(line['image'], line['image_path'])
|
||||
if skip_pdf_parse:
|
||||
line['image'] = line['image_path']
|
||||
else:
|
||||
pdf_data = base64.b64decode(line['image'])
|
||||
pdf_file = io.BytesIO(pdf_data)
|
||||
encoded_images = []
|
||||
with fitz.open(stream=pdf_file, filetype='pdf') as doc:
|
||||
doc = doc[:self.max_pages]
|
||||
for page in doc:
|
||||
image = page.get_pixmap(dpi=144)
|
||||
image_file = io.BytesIO(image.tobytes(output='png'))
|
||||
image = Image.open(image_file)
|
||||
encoded_image = encode_image_to_base64(image)
|
||||
encoded_images.append(encoded_image)
|
||||
line['image'] = encoded_images
|
||||
print('process {}'.format(line['doc_id']))
|
||||
|
||||
if 'image' in line:
|
||||
if isinstance(line['image'], list):
|
||||
tgt_path = []
|
||||
assert 'image_path' in line
|
||||
for img, im_name in zip(line['image'], line['image_path']):
|
||||
path = osp.join(self.img_root, im_name)
|
||||
if not read_ok(path):
|
||||
decode_base64_to_image_file(img, path)
|
||||
tgt_path.append(path)
|
||||
else:
|
||||
tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
|
||||
if not read_ok(tgt_path):
|
||||
decode_base64_to_image_file(line['image'], tgt_path)
|
||||
tgt_path = [tgt_path]
|
||||
else:
|
||||
assert 'image_path' in line
|
||||
tgt_path = toliststr(line['image_path'])
|
||||
|
||||
if self.concat_num > 0 and not self.is_api:
|
||||
concatenated_images = concat_images(tgt_path, max_concat=self.concat_num, column_num=self.column_num)
|
||||
|
||||
old_tgt_path = tgt_path
|
||||
assert isinstance(old_tgt_path, list)
|
||||
if self.column_num != -1:
|
||||
tgt_path = [
|
||||
'_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat{}_{}.jpg'.format(self.concat_num, i)
|
||||
for i in range(len(concatenated_images))
|
||||
]
|
||||
else:
|
||||
tgt_path = ['_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat_all.jpg']
|
||||
|
||||
for path, concatenated_image in zip(tgt_path, concatenated_images):
|
||||
if not read_ok(path):
|
||||
decode_base64_to_image_file(encode_image_to_base64(concatenated_image), path)
|
||||
num_images, image_size = len(old_tgt_path), concatenated_image.size
|
||||
print('concat {} images to a new one with size {}. save at {}'.format(num_images, image_size, path))
|
||||
return tgt_path
|
||||
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
logger = get_logger('Evaluation')
|
||||
model = judge_kwargs['model']
|
||||
|
||||
suffix = eval_file.split('.')[-1]
|
||||
storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
|
||||
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
||||
|
||||
if osp.exists(storage):
|
||||
logger.warning(f'GPT scoring file {storage} already exists, will reuse it in DUDE_eval. ')
|
||||
else:
|
||||
data = load(eval_file)
|
||||
model = build_judge(max_tokens=128, **judge_kwargs)
|
||||
lt = len(data)
|
||||
lines = [data.iloc[i] for i in range(lt)]
|
||||
tups = [(model, line) for line in lines]
|
||||
indices = [line['index'] for line in lines]
|
||||
|
||||
ans = {}
|
||||
if osp.exists(tmp_file):
|
||||
ans = load(tmp_file)
|
||||
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
||||
indices = [i for i in indices if i not in ans]
|
||||
|
||||
if len(indices):
|
||||
new_results = list()
|
||||
for model, line in tqdm(tups):
|
||||
res = MMLongBench_auxeval(model, line)
|
||||
new_results.append(res)
|
||||
|
||||
log_map, res_map, pred_map = {}, {}, {}
|
||||
all_inds = [line['index'] for line in lines]
|
||||
for k, v in zip(all_inds, new_results):
|
||||
log_map[k] = v['log']
|
||||
res_map[k] = v['res']
|
||||
pred_map[k] = v['pred']
|
||||
data['res'] = [res_map[idx] for idx in data['index']]
|
||||
data['log'] = [log_map[idx] for idx in data['index']]
|
||||
data['pred'] = [pred_map[idx] for idx in data['index']]
|
||||
dump(data, storage)
|
||||
|
||||
score = DUDE_acc(storage)
|
||||
score_pth = storage.replace('.xlsx', '_score.csv')
|
||||
|
||||
dump(score, score_pth)
|
||||
logger.info(f'DUDE successfully finished evaluating {eval_file}, results saved in {score_pth}')
|
||||
logger.info('Score: ')
|
||||
logger.info(score)
|
||||
240  eval_mm/vlmevalkit/vlmeval/dataset/dynamath.py  (new file)
@@ -0,0 +1,240 @@
|
||||
import re
|
||||
import json
|
||||
import sympy as sp
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sympy import simplify, Eq, sympify, Pow, pi
|
||||
from sympy.parsing.latex import parse_latex
|
||||
import sys
|
||||
import math
|
||||
import os
|
||||
import os.path as osp
|
||||
import argparse
|
||||
|
||||
from .image_base import ImageBaseDataset
|
||||
from .utils import build_judge
|
||||
from ..utils import track_progress_rich
|
||||
from ..smp import load, dump, d2df, toliststr
|
||||
|
||||
|
||||
def preprocess(str1):
|
||||
if 0 <= str1.find("{") < str1.rfind("}"):
|
||||
str1 = str1[str1.find("{"): str1.rfind("}") + 1]
|
||||
str2 = str1.replace("\\", "")
|
||||
str2 = str2.replace("\\n", "\n")
|
||||
return str2
|
||||
|
||||
|
||||
def transfer(str1):
|
||||
if "\u03c0" in str1:
|
||||
strs = str1.split('\u03c0')
|
||||
str1 = strs[0]
|
||||
return float(str1) * np.pi
|
||||
else:
|
||||
return float(str1)
|
||||
|
||||
|
||||
def parse_answer(answer, answer_type="multiple choice"):
|
||||
if answer_type == "float":
|
||||
if answer.isdigit():
|
||||
return True, float(answer)
|
||||
else:
|
||||
parts = answer.split(' ')
|
||||
answer = parts[0]
|
||||
try:
|
||||
answer = transfer(answer)
|
||||
return True, answer
|
||||
except:
|
||||
return False, None
|
||||
elif answer_type == "multiple choice":
|
||||
if len(answer) == 1:
|
||||
return True, answer.upper()
|
||||
else:
|
||||
in_flag = [ch in answer.upper() for ch in 'ABCDE']
|
||||
if sum(in_flag) == 1:
|
||||
for ch in 'ABCDE':
|
||||
if ch in answer.upper():
|
||||
return True, ch
|
||||
return False, None
|
||||
else:
|
||||
return True, answer
|
||||
|
||||
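# Illustrative checks (not part of the original file), assuming the helpers above:
def _parse_answer_sketch_checks():
    assert parse_answer("3.14 cm", "float") == (True, 3.14)
    assert parse_answer("(C)", "multiple choice") == (True, 'C')   # exactly one of A-E found
    assert parse_answer("blue whale", "text") == (True, 'blue whale')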
|
||||
def DynaMath_auxeval(model, line):
|
||||
pred = line['prediction']
|
||||
pred = preprocess(pred)
|
||||
|
||||
succeed, short_answer = None, None
|
||||
try:
|
||||
dj = json.loads(pred, strict=False)
|
||||
short_answer = dj.get("short answer")
|
||||
assert short_answer is not None
|
||||
succeed, short_answer = parse_answer(short_answer, answer_type=line['answer_type'])
|
||||
assert succeed
|
||||
except:
|
||||
# Failed to parse the JSON, use an auxiliary LLM to get the short answer
|
||||
if line['answer_type'] == 'multiple choice':
|
||||
inst = "Output the corresponing choice option, such as 'A', 'B', 'C', 'D', in a single line."
|
||||
elif line['answer_type'] == 'float':
|
||||
inst = "Output a three-digit floating-point number in a single line."
|
||||
else:
|
||||
inst = (
|
||||
"Output a short answer in a single line. Any float numbers in the answer "
|
||||
"should be formatted as three-digit floating-point numbers."
|
||||
)
|
||||
|
||||
prompt = f"Free-form answer: {pred}\nInstruction: {inst}"
|
||||
response = pred
|
||||
succeed, short_answer = parse_answer(response, line['answer_type'])
|
||||
if not succeed:
|
||||
response = model.generate(prompt)
|
||||
succeed, short_answer = parse_answer(response, line['answer_type'])
|
||||
|
||||
if line['answer_type'] == 'float':
|
||||
if succeed:
|
||||
diff = float(short_answer) - float(line['answer'])
|
||||
if abs(diff) <= 0.001:
|
||||
return dict(parse=True, extracted=short_answer, correct=True)
|
||||
else:
|
||||
return dict(parse=True, extracted=short_answer, correct=False)
|
||||
else:
|
||||
return dict(parse=False, extracted=None, correct=False)
|
||||
elif line['answer_type'] == 'multiple choice':
|
||||
if succeed:
|
||||
return dict(parse=True, extracted=short_answer, correct=(short_answer == line['answer']))
|
||||
else:
|
||||
if line['answer'] in pred[:3].upper():
|
||||
return dict(parse=False, extracted=None, correct=True)
|
||||
else:
|
||||
return dict(parse=False, extracted=None, correct=False)
|
||||
else:
|
||||
if succeed:
|
||||
return dict(parse=True, extracted=short_answer, correct=(short_answer.lower() in line['answer'].lower()))
|
||||
else:
|
||||
return dict(parse=False, extracted=None, correct=False)  # nothing was parsed, so it cannot be scored correct
|
||||
|
||||
|
||||
class Dynamath(ImageBaseDataset):
|
||||
|
||||
TYPE = 'VQA'
|
||||
DATASET_URL = {'DynaMath': 'https://opencompass.openxlab.space/utils/VLMEval/DynaMath.tsv'}
|
||||
DATASET_MD5 = {'DynaMath': 'b8425ad9a7114571fc9366e013699494'}
|
||||
GUIDE = """
|
||||
## Answer Instruction Please provide an answer to the question outlined above. Your response should adhere \
|
||||
to the following JSON format, which includes two keys: 'solution' and 'short answer'. The 'solution' key can contain \
|
||||
detailed steps needed to solve the question, and the 'short answer' key should provide a concise response. {INST}
|
||||
|
||||
Example of expected JSON response format:
|
||||
|
||||
"""
|
||||
EXAMPLE = {
|
||||
"solution": "[Detailed step-by-step explanation]",
|
||||
"short answer": "[Concise Answer]"
|
||||
}
|
||||
TEXT_EXAMPLE = json.dumps(EXAMPLE, indent=4)
|
||||
|
||||
# Given one data record, return the built prompt (a multi-modal message), can override
|
||||
def build_prompt(self, line):
|
||||
if isinstance(line, int):
|
||||
line = self.data.iloc[line]
|
||||
|
||||
if self.meta_only:
|
||||
tgt_path = toliststr(line['image_path'])
|
||||
else:
|
||||
tgt_path = self.dump_image(line)
|
||||
|
||||
prompt = f"## Question\n {line['question']}"
|
||||
if line['answer_type'] == 'multiple choice':
|
||||
inst = "Provide the corresponing choice option in the 'short answer' key, such as 'A', 'B', 'C', or 'D'."
|
||||
elif line['answer_type'] == 'float':
|
||||
inst = "Format the answer as a three-digit floating-point number and provide it in the 'short answer' key."
|
||||
else:
|
||||
inst = "Float numbers in the answer should be formatted as three-digit floating-point numbers."
|
||||
|
||||
prompt = prompt + self.GUIDE.format(INST=inst) + self.TEXT_EXAMPLE
|
||||
|
||||
msgs = []
|
||||
if isinstance(tgt_path, list):
|
||||
msgs.extend([dict(type='image', value=p) for p in tgt_path])
|
||||
else:
|
||||
msgs = [dict(type='image', value=tgt_path)]
|
||||
msgs.append(dict(type='text', value=prompt))
|
||||
return msgs
|
||||
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
judge_name = judge_kwargs.pop('model', 'gpt-4o-mini')
|
||||
|
||||
model = build_judge(model=judge_name, **judge_kwargs)
|
||||
suffix = eval_file.split('.')[-1]
|
||||
|
||||
storage = eval_file.replace(f'.{suffix}', f'_{judge_name}.xlsx') # noqa: F841
|
||||
score_file = eval_file.replace(f'.{suffix}', f'_{judge_name}_score.csv') # noqa: F841
|
||||
tmp_file = eval_file.replace(f'.{suffix}', f'_{judge_name}.pkl') # noqa: F841
|
||||
nproc = judge_kwargs.pop('nproc', 6) # noqa: F841
|
||||
|
||||
res = load(tmp_file) if os.path.exists(tmp_file) else {}
|
||||
res = {k: v for k, v in res.items() if v is not None}
|
||||
|
||||
model.system_prompt = """\
|
||||
You are a helpful assistant that helps me to format free-form answers into a short answer according to the instruction.
|
||||
"""
|
||||
if not osp.exists(storage):
|
||||
data = load(eval_file)
|
||||
lt = len(data)
|
||||
payloads = [dict(model=model, line=data.iloc[i]) for i in range(lt) if data.iloc[i]['index'] not in res]
|
||||
keys = [idx for idx in data['index'] if idx not in res]
|
||||
|
||||
if len(keys):
|
||||
results = track_progress_rich(DynaMath_auxeval, payloads, nproc=nproc, save=tmp_file, keys=keys)
|
||||
for k, r in zip(keys, results):
|
||||
res[k] = r
|
||||
|
||||
data['parse'] = [res[idx]['parse'] for idx in data['index']]
|
||||
data['extracted'] = [res[idx]['extracted'] for idx in data['index']]
|
||||
data['correct'] = [res[idx]['correct'] for idx in data['index']]
|
||||
dump(data, storage)
|
||||
|
||||
data = load(storage)
|
||||
# Calculate Average Accuracy
|
||||
score_avg = {}
|
||||
score_avg['Overall'] = np.mean(data['correct'])
|
||||
|
||||
subs = set(data['subject'])
|
||||
for sub in subs:
|
||||
data_sub = data[data['subject'] == sub]
|
||||
score_avg[f'Subject-{sub}'] = np.mean(data_sub['correct'])
|
||||
|
||||
lvls = set(data['knowledge_level'])
|
||||
for lvl in lvls:
|
||||
data_lvl = data[data['knowledge_level'] == lvl]
|
||||
score_avg[f'Level-{lvl}'] = np.mean(data_lvl['correct'])
|
||||
|
||||
# Calculate the Worst Case Accuracy
|
||||
score_worst = {}
|
||||
data_worst = data[data['varid'] == 1]
|
||||
qid2corr = {qid: True for qid in data_worst['qid']}
|
||||
lt = len(data)
|
||||
for i in range(lt):
|
||||
item = data.iloc[i]
|
||||
qid2corr[item['qid']] *= item['correct']
|
||||
data_worst['correct'] = [qid2corr[idx] for idx in data_worst['qid']]
|
||||
score_worst['Overall'] = np.mean(data_worst['correct'])
|
||||
|
||||
subs = set(data_worst['subject'])
|
||||
for sub in subs:
|
||||
data_sub = data_worst[data_worst['subject'] == sub]
|
||||
score_worst[f'Subject-{sub}'] = np.mean(data_sub['correct'])
|
||||
|
||||
lvls = set(data_worst['knowledge_level'])
|
||||
for lvl in lvls:
|
||||
data_lvl = data_worst[data_worst['knowledge_level'] == lvl]
|
||||
score_worst[f'Level-{lvl}'] = np.mean(data_lvl['correct'])
|
||||
|
||||
d1 = {'Setting': 'Average'}
|
||||
d1.update(score_avg)
|
||||
d2 = {'Setting': 'Worst Case'}
|
||||
d2.update(score_worst)
|
||||
score = pd.concat([d2df(d1), d2df(d2)], ignore_index=True)
|
||||
|
||||
dump(score, score_file)
|
||||
return score
|
||||
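# Illustrative sketch (not part of the original file): "worst case" accuracy marks a
# seed question correct only if every variant (same qid, varid 1..N) was answered
# correctly; with pandas this is a groupby-all over hypothetical records:
def _worst_case_sketch():
    df = pd.DataFrame({'qid': [1, 1, 2, 2], 'varid': [1, 2, 1, 2],
                       'correct': [True, False, True, True]})
    worst = df.groupby('qid')['correct'].all()     # qid 1 -> False, qid 2 -> True
    return float(worst.mean())                     # 0.5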
172  eval_mm/vlmevalkit/vlmeval/dataset/image_base.py  (new file)
@@ -0,0 +1,172 @@
|
||||
import pandas as pd
|
||||
from abc import abstractmethod
|
||||
from ..smp import *
|
||||
|
||||
|
||||
def img_root_map(dataset):
|
||||
if 'MM_NIAH' in dataset:
|
||||
return 'MMNIAH'
|
||||
if 'CRPE' in dataset:
|
||||
return 'CRPE'
|
||||
if 'OCRVQA' in dataset:
|
||||
return 'OCRVQA'
|
||||
if 'COCO_VAL' == dataset:
|
||||
return 'COCO'
|
||||
if 'MMMU' in dataset:
|
||||
return 'MMMU'
|
||||
if "QSpatial" in dataset:
|
||||
return "QSpatial"
|
||||
|
||||
mmbench_root_map = {
|
||||
'MMBench_DEV_EN': 'MMBench', 'MMBench_TEST_EN': 'MMBench',
|
||||
'MMBench_DEV_CN': 'MMBench', 'MMBench_TEST_CN': 'MMBench',
|
||||
'MMBench': 'MMBench', 'MMBench_CN': 'MMBench',
|
||||
'MMBench_DEV_EN_V11': 'MMBench_V11', 'MMBench_TEST_EN_V11': 'MMBench_V11',
|
||||
'MMBench_DEV_CN_V11': 'MMBench_V11', 'MMBench_TEST_CN_V11': 'MMBench_V11',
|
||||
'MMBench_V11': 'MMBench', 'MMBench_CN_V11': 'MMBench',
|
||||
}
|
||||
if dataset in mmbench_root_map:
|
||||
return mmbench_root_map[dataset]
|
||||
return dataset
|
||||
|
||||
|
||||
class ImageBaseDataset:
|
||||
|
||||
MODALITY = 'IMAGE'
|
||||
DATASET_URL = {}
|
||||
DATASET_MD5 = {}
|
||||
|
||||
def __init__(self, dataset='MMBench', skip_noimg=True):
|
||||
ROOT = LMUDataRoot()
|
||||
# You can override this variable to save image files to a different directory
|
||||
self.dataset_name = dataset
|
||||
self.img_root = osp.join(ROOT, 'images', img_root_map(dataset))
|
||||
|
||||
data = self.load_data(dataset)
|
||||
self.skip_noimg = skip_noimg
|
||||
if skip_noimg and 'image' in data:
|
||||
data = data[~pd.isna(data['image'])]
|
||||
|
||||
data['index'] = [str(x) for x in data['index']]
|
||||
|
||||
self.meta_only = True
|
||||
|
||||
# The image field can store the base64 encoded image or another question index (for saving space)
|
||||
if 'image' in data:
|
||||
data['image'] = [str(x) for x in data['image']]
|
||||
image_map = {x: y for x, y in zip(data['index'], data['image'])}
|
||||
for k in image_map:
|
||||
if len(image_map[k]) <= 64:
|
||||
idx = image_map[k]
|
||||
assert idx in image_map and len(image_map[idx]) > 64
|
||||
image_map[k] = image_map[idx]
|
||||
|
||||
images = [toliststr(image_map[k]) for k in data['index']]
|
||||
data['image'] = [x[0] if len(x) == 1 else x for x in images]
|
||||
self.meta_only = False
|
||||
|
||||
if 'image_path' in data:
|
||||
paths = [toliststr(x) for x in data['image_path']]
|
||||
data['image_path'] = [x[0] if len(x) == 1 else x for x in paths]
|
||||
|
||||
if np.all([istype(x, int) for x in data['index']]):
|
||||
data['index'] = [int(x) for x in data['index']]
|
||||
|
||||
self.data = data
|
||||
self.post_build(dataset)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return dict(self.data.iloc[idx])
|
||||
|
||||
def prepare_tsv(self, url, file_md5=None):
|
||||
data_root = LMUDataRoot()
|
||||
os.makedirs(data_root, exist_ok=True)
|
||||
update_flag = False
|
||||
file_name = url.split('/')[-1]
|
||||
data_path = osp.join(data_root, file_name)
|
||||
if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
|
||||
pass
|
||||
else:
|
||||
warnings.warn('The dataset tsv is not downloaded')
|
||||
download_file(url, data_path)
|
||||
update_flag = True
|
||||
|
||||
if file_size(data_path, 'GB') > 1:
|
||||
local_path = data_path.replace('.tsv', '_local.tsv')
|
||||
if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None) or update_flag:
|
||||
from ..tools import LOCALIZE
|
||||
LOCALIZE(data_path, local_path)
|
||||
data_path = local_path
|
||||
return load(data_path)
|
||||
|
||||
def dump_image(self, line):
|
||||
os.makedirs(self.img_root, exist_ok=True)
|
||||
|
||||
if 'image' in line:
|
||||
if isinstance(line['image'], list):
|
||||
tgt_path = []
|
||||
assert 'image_path' in line
|
||||
for img, im_name in zip(line['image'], line['image_path']):
|
||||
path = osp.join(self.img_root, im_name)
|
||||
if not read_ok(path):
|
||||
decode_base64_to_image_file(img, path)
|
||||
tgt_path.append(path)
|
||||
else:
|
||||
tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
|
||||
if not read_ok(tgt_path):
|
||||
decode_base64_to_image_file(line['image'], tgt_path)
|
||||
tgt_path = [tgt_path]
|
||||
else:
|
||||
assert 'image_path' in line
|
||||
tgt_path = toliststr(line['image_path'])
|
||||
|
||||
return tgt_path
|
||||
|
||||
def display(self, line):
|
||||
if isinstance(line, int):
|
||||
line = self.data.iloc[line]
|
||||
assert isinstance(line, pd.Series) or isinstance(line, dict)
|
||||
mmqa_display(line)
|
||||
|
||||
# Return a list of dataset names that are supported by this class, can override
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return list(cls.DATASET_URL)
|
||||
|
||||
# Given the dataset name, return the dataset as a pandas dataframe, can override
|
||||
def load_data(self, dataset):
|
||||
url = self.DATASET_URL[dataset]
|
||||
file_md5 = self.DATASET_MD5[dataset] if dataset in self.DATASET_MD5 else None
|
||||
return self.prepare_tsv(url, file_md5)
|
||||
|
||||
# Post built hook, will be called after the dataset is built, can override
|
||||
def post_build(self, dataset):
|
||||
pass
|
||||
|
||||
# Given one data record, return the built prompt (a multi-modal message), can override
|
||||
def build_prompt(self, line):
|
||||
if isinstance(line, int):
|
||||
line = self.data.iloc[line]
|
||||
|
||||
if self.meta_only:
|
||||
tgt_path = toliststr(line['image_path'])
|
||||
else:
|
||||
tgt_path = self.dump_image(line)
|
||||
|
||||
question = line['question']
|
||||
|
||||
msgs = []
|
||||
if isinstance(tgt_path, list):
|
||||
msgs.extend([dict(type='image', value=p) for p in tgt_path])
|
||||
else:
|
||||
msgs = [dict(type='image', value=tgt_path)]
|
||||
msgs.append(dict(type='text', value=question))
|
||||
return msgs
|
||||
|
||||
# Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
|
||||
@abstractmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
pass
|
||||
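# Illustrative sketch (not part of the original file): a new benchmark usually only
# needs to point DATASET_URL / DATASET_MD5 at a TSV and implement evaluate().
# All names below are hypothetical; load, d2df and np come from the star import above.
class _ToyBenchmarkSketch(ImageBaseDataset):
    TYPE = 'VQA'
    DATASET_URL = {'ToyBench': 'https://example.com/ToyBench.tsv'}   # placeholder URL
    DATASET_MD5 = {}                                                 # skip the md5 check

    def evaluate(self, eval_file, **judge_kwargs):
        data = load(eval_file)
        data['hit'] = [str(a).lower() in str(p).lower()
                       for a, p in zip(data['answer'], data['prediction'])]
        return d2df({'Overall': float(np.mean(data['hit']))})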
75  eval_mm/vlmevalkit/vlmeval/dataset/image_caption.py  (new file)
@@ -0,0 +1,75 @@
|
||||
from .image_base import ImageBaseDataset
|
||||
from ..smp import *
|
||||
|
||||
|
||||
class COCO_Caption_Scorer():
|
||||
def __init__(self, ref, gt):
|
||||
from pycocoevalcap.bleu.bleu import Bleu
|
||||
from pycocoevalcap.rouge.rouge import Rouge
|
||||
from pycocoevalcap.cider.cider import Cider
|
||||
|
||||
self.ref = ref
|
||||
self.gt = gt
|
||||
print('setting up scorers...')
|
||||
self.scorers = [
|
||||
(Bleu(4), ['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4']),
|
||||
(Rouge(), 'ROUGE_L'),
|
||||
(Cider(), 'CIDEr'),
|
||||
]
|
||||
|
||||
def compute_scores(self):
|
||||
total_scores = {}
|
||||
for scorer, method in self.scorers:
|
||||
print('computing %s score...' % (scorer.method()))
|
||||
score, scores = scorer.compute_score(self.gt, self.ref)
|
||||
if isinstance(method, list):
|
||||
for sc, scs, m in zip(score, scores, method):
|
||||
print('%s: %0.3f' % (m, sc * 100))
|
||||
total_scores['Bleu'] = [x * 100 for x in score]
|
||||
else:
|
||||
print('%s: %0.3f' % (method, score * 100))
|
||||
total_scores[method] = score * 100
|
||||
|
||||
print('*****DONE*****')
|
||||
for key, value in total_scores.items():
|
||||
print('{}:{}'.format(key, value))
|
||||
return total_scores
|
||||
|
||||
|
||||
class ImageCaptionDataset(ImageBaseDataset):
|
||||
|
||||
TYPE = 'Caption'
|
||||
|
||||
DATASET_URL = {
|
||||
'COCO_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/COCO_VAL.tsv',
|
||||
}
|
||||
|
||||
DATASET_MD5 = {
|
||||
'COCO_VAL': '72a5079dead060269ac222c5aa5128af',
|
||||
}
|
||||
|
||||
def load_data(self, dataset):
|
||||
data = super().load_data(dataset)
|
||||
if 'question' not in data:
|
||||
data['question'] = [(
|
||||
'Please describe this image in general. Directly provide the description, '
|
||||
'do not include prefix like "This image depicts". '
|
||||
)] * len(data)
|
||||
return data
|
||||
|
||||
# It returns a dictionary of scores
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **kwargs):
|
||||
data = load(eval_file)
|
||||
lt = len(data)
|
||||
lines = [data.iloc[i] for i in range(lt)]
|
||||
ref, gt = {}, {}
|
||||
for i, line in enumerate(lines):
|
||||
ref[str(i)] = [str(line['prediction'])]
|
||||
gt[str(i)] = eval(line['answer'])
|
||||
|
||||
scorer = COCO_Caption_Scorer(ref, gt)
|
||||
coco_caption_score_dict = scorer.compute_scores()
|
||||
score_pth = eval_file.replace('.xlsx', '_score.json')
|
||||
dump(coco_caption_score_dict, score_pth)
|
||||
return coco_caption_score_dict
|
||||
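# Illustrative sketch (not part of the original file): the scorer takes
# {sample_id: [caption, ...]} dicts for predictions and references
# (pycocoevalcap must be installed); hypothetical inputs:
def _caption_scorer_sketch():
    ref = {'0': ['a dog runs on the beach']}                              # predictions
    gt = {'0': ['a dog running along the shore', 'a dog on a beach']}     # references
    return COCO_Caption_Scorer(ref, gt).compute_scores()   # Bleu / ROUGE_L / CIDEr dict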
197  eval_mm/vlmevalkit/vlmeval/dataset/image_ccocr.py  (new file)
@@ -0,0 +1,197 @@
|
||||
# flake8: noqa
|
||||
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from functools import partial
|
||||
import pandas as pd
|
||||
|
||||
from .image_base import ImageBaseDataset
|
||||
from ..smp import *
|
||||
|
||||
# should be the same as FAIL_MSG defined in vlmeval/inference.py
|
||||
FAIL_MSG = 'Failed to obtain answer via API.'
|
||||
|
||||
|
||||
class CCOCRDataset(ImageBaseDataset):
|
||||
TYPE = 'VQA'
|
||||
DATASET_URL_MODELSCOPE = {
|
||||
"CCOCR_DocParsing_DocPhotoChn": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/doc_parsing/doc/doc_photo_chn_75.tsv",
|
||||
"CCOCR_DocParsing_DocPhotoEng": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/doc_parsing/doc/doc_photo_eng_75.tsv",
|
||||
"CCOCR_DocParsing_DocScanChn": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/doc_parsing/doc/doc_scan_chn_75.tsv",
|
||||
"CCOCR_DocParsing_DocScanEng": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/doc_parsing/doc/doc_scan_eng_75.tsv",
|
||||
"CCOCR_DocParsing_TablePhotoChn": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/doc_parsing/table/table_photo_chn_75.tsv",
|
||||
"CCOCR_DocParsing_TablePhotoEng": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/doc_parsing/table/table_photo_eng_75.tsv",
|
||||
"CCOCR_DocParsing_TableScanChn": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/doc_parsing/table/table_scan_chn_75.tsv",
|
||||
"CCOCR_DocParsing_TableScanEng": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/doc_parsing/table/table_scan_eng_75.tsv",
|
||||
"CCOCR_DocParsing_MolecularHandwriting": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/doc_parsing/molecular/molecular_handwriting_100.tsv",
|
||||
"CCOCR_DocParsing_FormulaHandwriting": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/doc_parsing/formula/formula_handwriting_100.tsv",
|
||||
"CCOCR_Kie_Sroie2019Word": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/kie/constrained_category/sroie2019_word_347.tsv",
|
||||
"CCOCR_Kie_Cord": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/kie/constrained_category/CORD_100.tsv",
|
||||
"CCOCR_Kie_EphoieScut": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/kie/constrained_category/EPHOIE_SCUT_311.tsv",
|
||||
"CCOCR_Kie_Poie": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/kie/constrained_category/POIE_250.tsv",
|
||||
"CCOCR_Kie_ColdSibr": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/kie/open_category/COLD_SIBR_400.tsv",
|
||||
"CCOCR_Kie_ColdCell": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/kie/open_category/COLD_CELL_600.tsv",
|
||||
"CCOCR_MultiLanOcr_Arabic": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_lan_ocr/Arabic/Arabic_150.tsv",
|
||||
"CCOCR_MultiLanOcr_French": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_lan_ocr/French/French_150.tsv",
|
||||
"CCOCR_MultiLanOcr_German": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_lan_ocr/German/German_150.tsv",
|
||||
"CCOCR_MultiLanOcr_Italian": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_lan_ocr/Italian/Italian_150.tsv",
|
||||
"CCOCR_MultiLanOcr_Japanese": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_lan_ocr/Japanese/Japanese_150.tsv",
|
||||
"CCOCR_MultiLanOcr_Korean": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_lan_ocr/Korean/Korean_150.tsv",
|
||||
"CCOCR_MultiLanOcr_Portuguese": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_lan_ocr/Portuguese/Portuguese_150.tsv",
|
||||
"CCOCR_MultiLanOcr_Russian": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_lan_ocr/Russian/Russian_150.tsv",
|
||||
"CCOCR_MultiLanOcr_Spanish": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_lan_ocr/Spanish/Spanish_150.tsv",
|
||||
"CCOCR_MultiLanOcr_Vietnamese": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_lan_ocr/Vietnamese/Vietnamese_150.tsv",
|
||||
"CCOCR_MultiSceneOcr_Cord": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_scene_ocr/document_text/CORD_100.tsv",
|
||||
"CCOCR_MultiSceneOcr_Funsd": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_scene_ocr/document_text/FUNSD_50.tsv",
|
||||
"CCOCR_MultiSceneOcr_Iam": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_scene_ocr/document_text/IAM_50.tsv",
|
||||
"CCOCR_MultiSceneOcr_ZhDoc": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_scene_ocr/document_text/zh_doc_100.tsv",
|
||||
"CCOCR_MultiSceneOcr_ZhHandwriting": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_scene_ocr/document_text/zh_handwriting_50.tsv",
|
||||
"CCOCR_MultiSceneOcr_Hieragent": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_scene_ocr/scene_text/Hieragent_100.tsv",
|
||||
"CCOCR_MultiSceneOcr_Ic15": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_scene_ocr/scene_text/IC15_500.tsv",
|
||||
"CCOCR_MultiSceneOcr_Inversetext": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_scene_ocr/scene_text/InverseText_500.tsv",
|
||||
"CCOCR_MultiSceneOcr_Totaltext": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_scene_ocr/scene_text/TotalText_300.tsv",
|
||||
"CCOCR_MultiSceneOcr_ZhScene": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_scene_ocr/scene_text/zh_scene_450.tsv",
|
||||
"CCOCR_MultiSceneOcr_UgcLaion": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_scene_ocr/ugc_text/ugc_laion_400.tsv",
|
||||
"CCOCR_MultiSceneOcr_ZhDense": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_scene_ocr/ugc_text/zh_dense_50.tsv",
|
||||
"CCOCR_MultiSceneOcr_ZhVertical": "https://www.modelscope.cn/datasets/Qwen/CC-OCR/resolve/master/multi_scene_ocr/ugc_text/zh_vertical_100.tsv"
|
||||
}
|
||||
|
||||
DATASET_URL_HUGGINGFACE = {
|
||||
"CCOCR_DocParsing_DocPhotoChn": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/doc_parsing/doc/doc_photo_chn_75.tsv",
|
||||
"CCOCR_DocParsing_DocPhotoEng": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/doc_parsing/doc/doc_photo_eng_75.tsv",
|
||||
"CCOCR_DocParsing_DocScanChn": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/doc_parsing/doc/doc_scan_chn_75.tsv",
|
||||
"CCOCR_DocParsing_DocScanEng": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/doc_parsing/doc/doc_scan_eng_75.tsv",
|
||||
"CCOCR_DocParsing_TablePhotoChn": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/doc_parsing/table/table_photo_chn_75.tsv",
|
||||
"CCOCR_DocParsing_TablePhotoEng": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/doc_parsing/table/table_photo_eng_75.tsv",
|
||||
"CCOCR_DocParsing_TableScanChn": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/doc_parsing/table/table_scan_chn_75.tsv",
|
||||
"CCOCR_DocParsing_TableScanEng": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/doc_parsing/table/table_scan_eng_75.tsv",
|
||||
"CCOCR_DocParsing_MolecularHandwriting": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/doc_parsing/molecular/molecular_handwriting_100.tsv",
|
||||
"CCOCR_DocParsing_FormulaHandwriting": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/doc_parsing/formula/formula_handwriting_100.tsv",
|
||||
"CCOCR_Kie_Sroie2019Word": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/kie/constrained_category/sroie2019_word_347.tsv",
|
||||
"CCOCR_Kie_Cord": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/kie/constrained_category/CORD_100.tsv",
|
||||
"CCOCR_Kie_EphoieScut": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/kie/constrained_category/EPHOIE_SCUT_311.tsv",
|
||||
"CCOCR_Kie_Poie": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/kie/constrained_category/POIE_250.tsv",
|
||||
"CCOCR_Kie_ColdSibr": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/kie/open_category/COLD_SIBR_400.tsv",
|
||||
"CCOCR_Kie_ColdCell": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/kie/open_category/COLD_CELL_600.tsv",
|
||||
"CCOCR_MultiLanOcr_Arabic": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_lan_ocr/Arabic/Arabic_150.tsv",
|
||||
"CCOCR_MultiLanOcr_French": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_lan_ocr/French/French_150.tsv",
|
||||
"CCOCR_MultiLanOcr_German": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_lan_ocr/German/German_150.tsv",
|
||||
"CCOCR_MultiLanOcr_Italian": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_lan_ocr/Italian/Italian_150.tsv",
|
||||
"CCOCR_MultiLanOcr_Japanese": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_lan_ocr/Japanese/Japanese_150.tsv",
|
||||
"CCOCR_MultiLanOcr_Korean": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_lan_ocr/Korean/Korean_150.tsv",
|
||||
"CCOCR_MultiLanOcr_Portuguese": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_lan_ocr/Portuguese/Portuguese_150.tsv",
|
||||
"CCOCR_MultiLanOcr_Russian": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_lan_ocr/Russian/Russian_150.tsv",
|
||||
"CCOCR_MultiLanOcr_Spanish": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_lan_ocr/Spanish/Spanish_150.tsv",
|
||||
"CCOCR_MultiLanOcr_Vietnamese": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_lan_ocr/Vietnamese/Vietnamese_150.tsv",
|
||||
"CCOCR_MultiSceneOcr_Cord": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_scene_ocr/document_text/CORD_100.tsv",
|
||||
"CCOCR_MultiSceneOcr_Funsd": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_scene_ocr/document_text/FUNSD_50.tsv",
|
||||
"CCOCR_MultiSceneOcr_Iam": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_scene_ocr/document_text/IAM_50.tsv",
|
||||
"CCOCR_MultiSceneOcr_ZhDoc": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_scene_ocr/document_text/zh_doc_100.tsv",
|
||||
"CCOCR_MultiSceneOcr_ZhHandwriting": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_scene_ocr/document_text/zh_handwriting_50.tsv",
|
||||
"CCOCR_MultiSceneOcr_Hieragent": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_scene_ocr/scene_text/Hieragent_100.tsv",
|
||||
"CCOCR_MultiSceneOcr_Ic15": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_scene_ocr/scene_text/IC15_500.tsv",
|
||||
"CCOCR_MultiSceneOcr_Inversetext": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_scene_ocr/scene_text/InverseText_500.tsv",
|
||||
"CCOCR_MultiSceneOcr_Totaltext": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_scene_ocr/scene_text/TotalText_300.tsv",
|
||||
"CCOCR_MultiSceneOcr_ZhScene": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_scene_ocr/scene_text/zh_scene_450.tsv",
|
||||
"CCOCR_MultiSceneOcr_UgcLaion": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_scene_ocr/ugc_text/ugc_laion_400.tsv",
|
||||
"CCOCR_MultiSceneOcr_ZhDense": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_scene_ocr/ugc_text/zh_dense_50.tsv",
|
||||
"CCOCR_MultiSceneOcr_ZhVertical": "https://huggingface.co/datasets/wulipc/CC-OCR/resolve/main/multi_scene_ocr/ugc_text/zh_vertical_100.tsv"
|
||||
}
|
||||
|
||||
# define data path
|
||||
DATASET_URL = DATASET_URL_MODELSCOPE
|
||||
DATASET_MD5 = {
|
||||
"CCOCR_DocParsing_DocPhotoChn": "9039dcbb31830d413261a95cfa29d97f",
|
||||
"CCOCR_DocParsing_DocPhotoEng": "2ca0824881e1d7317626f2a19d902989",
|
||||
"CCOCR_DocParsing_DocScanChn": "9e265c8aa760ebdf5c3bf9e892d55492",
|
||||
"CCOCR_DocParsing_DocScanEng": "77d04637be3def86dbc2ce37ba64a704",
|
||||
"CCOCR_DocParsing_TablePhotoChn": "c4dc85252ddad2b43a03a67b1d1ae983",
|
||||
"CCOCR_DocParsing_TablePhotoEng": "02ab75d6169da0cd2ece9ce0ae14a479",
|
||||
"CCOCR_DocParsing_TableScanChn": "f1f79959fdd01127df7377c9d46722f2",
|
||||
"CCOCR_DocParsing_TableScanEng": "794903c7acf52bfe956eefba2166d14b",
|
||||
"CCOCR_DocParsing_MolecularHandwriting": "30b7f7679b713ce000a939eca7b4078f",
|
||||
"CCOCR_DocParsing_FormulaHandwriting": "e03047776ce5e79a61ae1c057e2a348e",
|
||||
"CCOCR_Kie_Sroie2019Word": "3287d99a8e86a99b74171fa5a70f9acb",
|
||||
"CCOCR_Kie_Cord": "ab297cadcbc7158884a301c366f3330a",
|
||||
"CCOCR_Kie_EphoieScut": "bb8fa3ba7ea91cbf17be0904956ad3f3",
|
||||
"CCOCR_Kie_Poie": "882b64317989ecbfed6518051cdffb14",
|
||||
"CCOCR_Kie_ColdSibr": "109d5dad8b7081fb6a2f088e963196d4",
|
||||
"CCOCR_Kie_ColdCell": "7b44c45b4d7d768d1dbdc08872fe7d3a",
|
||||
"CCOCR_MultiLanOcr_Arabic": "e9a3f2bb9298d0b882ebc7a98980c3f3",
|
||||
"CCOCR_MultiLanOcr_French": "729407ed2036c22e602eff645eddd40c",
|
||||
"CCOCR_MultiLanOcr_German": "96fc2edae747f0ec95b0a6f9bf723022",
|
||||
"CCOCR_MultiLanOcr_Italian": "29a508fa5d5a5e767497dd69e2430ebb",
|
||||
"CCOCR_MultiLanOcr_Japanese": "bbcca96ccf25fff63597c2ab4f3ebb1f",
|
||||
"CCOCR_MultiLanOcr_Korean": "0f55dbd24eba5edc189c91e124411641",
|
||||
"CCOCR_MultiLanOcr_Portuguese": "a6fcf8831775a61aa631c0cf1c422ae7",
|
||||
"CCOCR_MultiLanOcr_Russian": "19d2f84062a1699d3e9333912bd6b303",
|
||||
"CCOCR_MultiLanOcr_Spanish": "f5a0cfa9f2ae4115c91c7b362034e591",
|
||||
"CCOCR_MultiLanOcr_Vietnamese": "bf1cd4e83d91767f4906f81550cec8b9",
|
||||
"CCOCR_MultiSceneOcr_Cord": "92943f0ccb4c5a196c574222e76759a0",
|
||||
"CCOCR_MultiSceneOcr_Funsd": "229cc38d193edd00f4383610e98ee873",
|
||||
"CCOCR_MultiSceneOcr_Iam": "d897a6d6c3880c65e752ec11b211204c",
|
||||
"CCOCR_MultiSceneOcr_ZhDoc": "303682cc16c8bb51b2b896f8ceb8bd38",
|
||||
"CCOCR_MultiSceneOcr_ZhHandwriting": "faa298d366bc05e5cfb39e334afb8eff",
|
||||
"CCOCR_MultiSceneOcr_Hieragent": "6f132cdd0473d7cc145c3e3a08957dd6",
|
||||
"CCOCR_MultiSceneOcr_Ic15": "3d94869f312a41d53d0578a06a2fb1f2",
|
||||
"CCOCR_MultiSceneOcr_Inversetext": "e141d424a0c4cf9579064428a270f13d",
|
||||
"CCOCR_MultiSceneOcr_Totaltext": "ca1daf81d49eeb57ef844b72a23c2e62",
|
||||
"CCOCR_MultiSceneOcr_ZhScene": "9295152a66e6f117db8bfbb20a9013e6",
|
||||
"CCOCR_MultiSceneOcr_UgcLaion": "8e9ea1fbf9d56532157e807eabf39b21",
|
||||
"CCOCR_MultiSceneOcr_ZhDense": "de8f48ee0c8a2cf8ed7f2b3a81e6322d",
|
||||
"CCOCR_MultiSceneOcr_ZhVertical": "4892b4aec6e7fd11e39aaea23712709b"
|
||||
}
|
||||
|
||||
# It returns the evaluation summary
def evaluate(self, eval_file, **judge_kwargs):
    """Evaluate CC-OCR predictions stored in ``eval_file`` and return the summary metrics."""
    df = load(eval_file)
    dict_list = df.to_dict(orient='records')

    required_column_list = ['answer', 'prediction', "category", "image_name", "l2-category", "split"]
    for required_column in required_column_list:
        assert required_column in df, "required_column: {} NOT found".format(required_column)

    # Build ground-truth / prediction maps keyed by image name
    gt_info, ptd_info = {}, {}
    for record in dict_list:
        image_name = record['image_name']
        gt_info[image_name] = record['answer']

        # skip samples whose prediction failed (FAIL_MSG)
        if record['prediction'] != FAIL_MSG:
            ptd_info[image_name] = record['prediction']

    # eval_file is assumed to hold a single CC-OCR subset, so these columns are constant
    group_name = set([str(x) for x in df['category']]).pop()
    op_name = set([str(x) for x in df['l2-category']]).pop()
    data_name = set([str(x) for x in df['split']]).pop()

    data_info = {"op": op_name, "group": group_name, "dataset": data_name, "num": len(gt_info)}
    try:
        from .utils.ccocr_evaluator import evaluator_map_info as ccocr_evaluator_map
    except ImportError as err:
        import warnings
        warnings.warn('The dependency of CCOCR evaluator is not properly installed')
        warnings.warn(f'{type(err)}: {err}')
        raise  # without the evaluator map, the code below cannot proceed
    eval_func = ccocr_evaluator_map.get(group_name, None)
    if eval_func is None:
        raise ValueError("error: evaluator not defined for: {}".format(group_name))
    meta_info, eval_info = eval_func(ptd_info, gt_info, **data_info)

    output_info = {"meta": meta_info, "evaluation": eval_info, "config": data_info}
    result_file = os.path.splitext(os.path.abspath(eval_file))[0] + "_eval.json"
    dump(output_info, result_file)

    # update the global status file used for the final summary
    # warning: the evaluate function should NOT run in parallel
    all_status_info = {}
    global_status_path = os.path.join(os.path.dirname(eval_file), "status.json")
    if os.path.exists(global_status_path):
        with open(global_status_path, "r") as f:
            all_status_info = json.load(f)
    all_status_info[data_name] = output_info
    with open(global_status_path, "w") as f:
        json.dump(all_status_info, f, ensure_ascii=False, indent=4)
    return eval_info.get("summary")
904
eval_mm/vlmevalkit/vlmeval/dataset/image_mcq.py
Normal file
@@ -0,0 +1,904 @@
|
||||
import warnings
|
||||
|
||||
from .image_base import ImageBaseDataset
|
||||
from .utils import build_judge, DEBUG_MESSAGE
|
||||
from ..smp import *
|
||||
import pandas as pd
|
||||
|
||||
MMMB_URLS = {
|
||||
'MMMB_ar': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_ar.tsv',
|
||||
'MMMB_cn': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_cn.tsv',
|
||||
'MMMB_en': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_en.tsv',
|
||||
'MMMB_pt': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_pt.tsv',
|
||||
'MMMB_ru': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_ru.tsv',
|
||||
'MMMB_tr': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_tr.tsv',
|
||||
}
|
||||
|
||||
MTL_MMBench_URLS = {
|
||||
'MMBench_dev_ar': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_ar.tsv',
|
||||
'MMBench_dev_cn': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_cn.tsv',
|
||||
'MMBench_dev_en': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_en.tsv',
|
||||
'MMBench_dev_pt': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_pt.tsv',
|
||||
'MMBench_dev_tr': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_tr.tsv',
|
||||
'MMBench_dev_ru': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_ru.tsv',
|
||||
}
|
||||
|
||||
MMMB_MD5 = {
|
||||
'MMMB_ar': 'f3a18b6385f1d9701840aa42de27aead', 'MMMB_cn': '13ed82fa89730037292fcaa27f08f430',
|
||||
'MMMB_en': '1cd781a71ec5a2983c090b84105d6a01', 'MMMB_pt': '548ea2b3bb2da991790386f0015d30d1',
|
||||
'MMMB_ru': 'ce1cc8a0533425ab0d86b326ebfc2984', 'MMMB_tr': '0733739d43090327975294292bc5cd67'
|
||||
}
|
||||
|
||||
MTL_MMBench_MD5 = {
|
||||
'MMBench_dev_ar': '4271b4a0d0200e1a86380a878e0d64a4', 'MMBench_dev_cn': '2ed5135326fed02c8e51ea50dda8222f',
|
||||
'MMBench_dev_en': 'd9ab776fc018b3d45785e9a5c23431c2', 'MMBench_dev_pt': '4ddfbcd27ef12444b908c03831cd0295',
|
||||
'MMBench_dev_tr': '4fab39d501389d3d6cc90264bb708f11', 'MMBench_dev_ru': '5ba1171ff2e68f80637bf78349e402a5'
|
||||
}
|
||||
|
||||
|
||||
class ImageMCQDataset(ImageBaseDataset):
|
||||
|
||||
TYPE = 'MCQ'
|
||||
|
||||
DATASET_URL = {
|
||||
# MMBench v1.0
|
||||
'MMBench_DEV_EN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_EN.tsv',
|
||||
'MMBench_TEST_EN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_EN.tsv',
|
||||
'MMBench_DEV_CN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_CN.tsv',
|
||||
'MMBench_TEST_CN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_CN.tsv',
|
||||
'MMBench': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench.tsv', # Internal
|
||||
'MMBench_CN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_CN.tsv', # Internal
|
||||
# MMBench v1.1
|
||||
'MMBench_DEV_EN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_EN_V11.tsv',
|
||||
'MMBench_TEST_EN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_EN_V11.tsv',
|
||||
'MMBench_DEV_CN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_CN_V11.tsv',
|
||||
'MMBench_TEST_CN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_CN_V11.tsv',
|
||||
'MMBench_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_V11.tsv', # Internal
|
||||
'MMBench_CN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_CN_V11.tsv', # Internal
|
||||
# SEEDBench Series
|
||||
'SEEDBench_IMG': 'https://opencompass.openxlab.space/utils/benchmarks/SEEDBench/SEEDBench_IMG.tsv',
|
||||
'SEEDBench2': 'https://huggingface.co/datasets/VLMEval/SEEDBench2/resolve/main/SEEDBench2.tsv',
|
||||
'SEEDBench2_Plus': 'https://opencompass.openxlab.space/utils/benchmarks/SEEDBench/SEEDBench2_Plus.tsv',
|
||||
# ScienceQA Series
|
||||
'ScienceQA_VAL': 'https://opencompass.openxlab.space/utils/benchmarks/ScienceQA/ScienceQA_VAL.tsv',
|
||||
'ScienceQA_TEST': 'https://opencompass.openxlab.space/utils/benchmarks/ScienceQA/ScienceQA_TEST.tsv',
|
||||
# MMT-Bench
|
||||
'MMT-Bench_ALL_MI': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_ALL_MI.tsv',
|
||||
'MMT-Bench_ALL': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_ALL.tsv',
|
||||
'MMT-Bench_VAL_MI': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_VAL_MI.tsv',
|
||||
'MMT-Bench_VAL': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_VAL.tsv',
|
||||
# AesBench
|
||||
'AesBench_VAL': 'https://huggingface.co/datasets/VLMEval/AesBench/resolve/main/AesBench_VAL.tsv',
|
||||
'AesBench_TEST': 'https://huggingface.co/datasets/VLMEval/AesBench/resolve/main/AesBench_TEST.tsv',
|
||||
# Q-Bench1
|
||||
'Q-Bench1_VAL': 'https://huggingface.co/datasets/zhangzicheng/qbench_tsv/resolve/main/Q-Bench1_VAL.tsv',
|
||||
'Q-Bench1_TEST': 'https://huggingface.co/datasets/zhangzicheng/qbench_tsv/resolve/main/Q-Bench1_TEST.tsv',
|
||||
# A-Bench
|
||||
'A-Bench_VAL': 'https://huggingface.co/datasets/zhangzicheng/abench_tsv/resolve/main/A-bench_VAL.tsv',
|
||||
'A-Bench_TEST': 'https://huggingface.co/datasets/zhangzicheng/abench_tsv/resolve/main/A-bench_TEST.tsv',
|
||||
# R-Bench
|
||||
'R-Bench-Dis': 'https://huggingface.co/datasets/lcysyzxdxc/R-Bench/blob/main/R-bench-dis.tsv',
|
||||
'R-Bench-Ref': 'https://huggingface.co/datasets/lcysyzxdxc/R-Bench/blob/main/R-bench-ref.tsv',
|
||||
# Other Benchmarks
|
||||
'CCBench': 'https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv',
|
||||
'AI2D_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv',
|
||||
'AI2D_TEST_NO_MASK': 'https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST_NO_MASK.tsv',
|
||||
'MMStar': 'https://opencompass.openxlab.space/utils/VLMEval/MMStar.tsv',
|
||||
'RealWorldQA': 'https://opencompass.openxlab.space/utils/VLMEval/RealWorldQA.tsv',
|
||||
'MLLMGuard_DS': 'https://opencompass.openxlab.space/utils/VLMEval/MLLMGuard_DS.tsv',
|
||||
'BLINK': 'https://opencompass.openxlab.space/utils/VLMEval/BLINK.tsv',
|
||||
'TaskMeAnything_v1_imageqa_random': (
|
||||
'https://huggingface.co/datasets/weikaih/TaskMeAnything-v1-imageqa-random/'
|
||||
'resolve/main/TaskMeAnything-v1-imageqa-random.tsv'
|
||||
),
|
||||
'A-OKVQA': 'https://huggingface.co/datasets/Allen8/A-OKVQA/resolve/main/a-okvqa.tsv',
|
||||
'WorldMedQA-V': 'https://opencompass.openxlab.space/utils/VLMEval/WorldMedQA-V.tsv',
|
||||
'VisOnlyQA-VLMEvalKit': (
|
||||
'https://huggingface.co/datasets/ryokamoi/VisOnlyQA_Eval_Real/'
|
||||
'resolve/main/visonlyqa_vlmevalkit.tsv'
|
||||
),
|
||||
'3DSRBench': (
|
||||
'https://huggingface.co/datasets/ccvl/3DSRBench/'
|
||||
'resolve/main/3dsrbench_v1_vlmevalkit_circular.tsv'
|
||||
),
|
||||
}
|
||||
|
||||
DATASET_MD5 = {
|
||||
# MMBench v1.0
|
||||
'MMBench_DEV_EN': 'b6caf1133a01c6bb705cf753bb527ed8',
|
||||
'MMBench_TEST_EN': '6939fadb0ce626fefc0bdc9c64efc528',
|
||||
'MMBench_DEV_CN': '08b8fc3324a5ed74155350f57be69fbd',
|
||||
'MMBench_TEST_CN': '7e1239baf0ee4c8b513e19705a0f317e',
|
||||
'MMBench': '4115aea3383f3dd0083be6a633e0f820', # Internal Only
|
||||
'MMBench_CN': '2e053ffc90ea598b1feae13c36dc13ee', # Internal Only
|
||||
# MMBench v1.1
|
||||
'MMBench_DEV_EN_V11': '30c05be8f2f347a50be25aa067248184',
|
||||
'MMBench_TEST_EN_V11': '26f0f15381a21720255091d3e0316ce6',
|
||||
'MMBench_DEV_CN_V11': '593f9b5f6bea453d870a798b34ae4f37',
|
||||
'MMBench_TEST_CN_V11': '74bbe4556dac745613c7cbe5ad787050',
|
||||
'MMBench_V11': 'b9276414f57af1308dcc4d0cd9b42e7c', # Internal Only
|
||||
'MMBench_CN_V11': '95f6980dd1b4de38e3cbffe0305a3f25', # Internal Only
|
||||
# SEEDBench
|
||||
'SEEDBench_IMG': '68017231464752261a2526d6ca3a10c0',
|
||||
'SEEDBench2': '4ec15cf864c4f16274112284f531813e',
|
||||
'SEEDBench2_Plus': 'e32d3216dc4f452b0fe497a52015d1fd',
|
||||
# ScienceQA
|
||||
'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3',
|
||||
'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f',
|
||||
# MMT-Bench
|
||||
'MMT-Bench_ALL_MI': '5272157097e19cdd7cb41e412ab3b7c7',
|
||||
'MMT-Bench_ALL': 'b273a2f4c596fe4f2605de0494cd632f',
|
||||
'MMT-Bench_VAL_MI': 'c7d7b998eb5cd9aa36c7d4f721472462',
|
||||
'MMT-Bench_VAL': '8dd4b730f53dbf9c3aed90ca31c928e0',
|
||||
# AesBench
|
||||
'AesBench_VAL': '3edb0c319e9187aa0b97fe7a11700a8c',
|
||||
'AesBench_TEST': '58b1f7ba2cc32e1d68896d6ee716bbf8',
|
||||
# Q-Bench1
|
||||
'Q-Bench1_VAL': '837bdb6cd2da571713543462815187b7',
|
||||
'Q-Bench1_TEST': '15e759bfd58c9d5f30b23a317d347153',
|
||||
# A-Bench
|
||||
'A-Bench_VAL': '218563ec50d34bb336c814143a5bb9c1',
|
||||
'A-Bench_TEST': '567013fb033a20cf23f51d8e865bd16c',
|
||||
# R-Bench
|
||||
'R-Bench-Dis': 'd6e961dbfc43350688af2560226830b4',
|
||||
'R-Bench-Ref': '270c1cb555acb523f3fdb178ed57021d',
|
||||
# Other Benchmarks
|
||||
'CCBench': 'f5dde47f24dc5a6fb6e595b409b466ac',
|
||||
'AI2D_TEST': '0f593e0d1c7df9a3d69bf1f947e71975',
|
||||
'AI2D_TEST_NO_MASK': 'fd8f463634d4fe9fbd23b876e8eea5be',
|
||||
'MMStar': 'e1ecd2140806c1b1bbf54b43372efb9e',
|
||||
'RealWorldQA': '4de008f55dc4fd008ca9e15321dc44b7',
|
||||
'MLLMGuard_DS': '975fc0dd7119386e198c37d71e274b3f',
|
||||
'BLINK': '3b6649b6a662184ea046908e5506260e',
|
||||
'TaskMeAnything_v1_imageqa_random': '023fef69e2ca21827afb77c5ec3bc889',
|
||||
'WorldMedQA-V': '441e63875e30c87f5750528b57b41285',
|
||||
"VisOnlyQA-VLMEvalKit": 'cf460a31d2acb8d3a7cecd0e69298bfa',
|
||||
'3DSRBench': '13a99f33164dc1b9faf0e8b8b01fd6f2',
|
||||
}
|
||||
|
||||
DATASET_URL.update(MMMB_URLS)
|
||||
DATASET_URL.update(MTL_MMBench_URLS)
|
||||
DATASET_MD5.update(MMMB_MD5)
|
||||
DATASET_MD5.update(MTL_MMBench_MD5)
|
||||
|
||||
def build_prompt(self, line):
|
||||
|
||||
if isinstance(line, int):
|
||||
line = self.data.iloc[line]
|
||||
|
||||
if self.meta_only:
|
||||
tgt_path = toliststr(line['image_path'])
|
||||
else:
|
||||
tgt_path = self.dump_image(line)
|
||||
|
||||
question = line['question']
|
||||
options = {
|
||||
cand: line[cand]
|
||||
for cand in string.ascii_uppercase
|
||||
if cand in line and not pd.isna(line[cand])
|
||||
}
|
||||
options_prompt = 'Options:\n'
|
||||
for key, item in options.items():
|
||||
options_prompt += f'{key}. {item}\n'
|
||||
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
|
||||
prompt = ''
|
||||
if hint is not None:
|
||||
prompt += f'Hint: {hint}\n'
|
||||
prompt += f'Question: {question}\n'
|
||||
if len(options):
|
||||
prompt += options_prompt
|
||||
prompt += 'Please select the correct answer from the options above. \n'
|
||||
|
||||
msgs = []
|
||||
if isinstance(tgt_path, list):
|
||||
msgs.extend([dict(type='image', value=p) for p in tgt_path])
|
||||
else:
|
||||
msgs = [dict(type='image', value=tgt_path)]
|
||||
msgs.append(dict(type='text', value=prompt))
|
||||
|
||||
return msgs
|
||||
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
from .utils.multiple_choice import report_acc, report_acc_MMT, mcq_circular_eval, mcq_vanilla_eval
|
||||
# assert dataset is not None
|
||||
dataset_map = {
|
||||
'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_EN_V11': 'MMBench_V11',
|
||||
'MMBench_TEST_CN': 'MMBench_CN', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11'
|
||||
}
|
||||
dataset = self.dataset_name
|
||||
if dataset in dataset_map:
|
||||
dataset = dataset_map[dataset]
|
||||
nproc = judge_kwargs.pop('nproc', 4)
|
||||
|
||||
circular = False
|
||||
if listinstr(['mmbench', 'ccbench'], dataset.lower()):
|
||||
data = load(eval_file)
|
||||
data['index'] = [int(x) for x in data['index']]
|
||||
dump(data, eval_file)
|
||||
circular = True
|
||||
|
||||
suffix = eval_file.split('.')[-1]
|
||||
model = judge_kwargs.get('model', 'exact_matching')
|
||||
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
||||
name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
|
||||
name_str = name_str_map[model] if model in name_str_map else model
|
||||
|
||||
if model == 'exact_matching':
|
||||
model = None
|
||||
elif gpt_key_set():
|
||||
model = build_judge(**judge_kwargs)
|
||||
if not model.working():
|
||||
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
||||
warnings.warn(DEBUG_MESSAGE)
|
||||
model = None
|
||||
else:
|
||||
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
||||
model = None
|
||||
|
||||
result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
|
||||
|
||||
data = load(eval_file)
|
||||
data = data.sort_values(by='index')
|
||||
data['prediction'] = [str(x) for x in data['prediction']]
|
||||
# If not choice label, then use lower case
|
||||
for k in data.keys():
|
||||
data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
|
||||
|
||||
meta = self.data
|
||||
meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
|
||||
data_map = {x: y for x, y in zip(data['index'], data['question'])}
|
||||
for k in data_map:
|
||||
assert k in meta_q_map, (
|
||||
f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
|
||||
)
|
||||
|
||||
if circular:
|
||||
data = mcq_circular_eval(model, data, meta, nproc, result_file, self.dataset_name)
|
||||
else:
|
||||
data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
|
||||
|
||||
# load split
|
||||
dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
||||
data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
||||
|
||||
# May have different report acc functions for different datasets
|
||||
if 'MMT' in dataset:
|
||||
acc = report_acc_MMT(data)
|
||||
else:
|
||||
acc = report_acc(data)
|
||||
|
||||
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
||||
dump(acc, score_file)
|
||||
|
||||
if dataset == 'AesBench_VAL':
|
||||
warnings.warn('Note that AesBench VAL is just a toy version of AesBench TEST. For full results, \
|
||||
please evaluate on AesBench TEST. The AesBench TEST dataset is more than 20 times \
|
||||
larger than the VAL dataset and the leaderboard results are based on AesBench TEST.')
|
||||
if dataset == 'VisOnlyQA-VLMEvalKit':
|
||||
warnings.warn('Note that the results on VisOnlyQA-VLMEvalKit are different from the results on \
|
||||
the original VisOnlyQA. VisOnlyQA-VLMEvalKit does not include the \
|
||||
chemistry__shape_multi split and uses a different evaluation prompt. Please \
|
||||
explicitly specify the version of the dataset when you report results.')
|
||||
|
||||
return acc
|
||||
|
||||
|
||||
class MMMUDataset(ImageMCQDataset):
|
||||
|
||||
DATASET_URL = {
|
||||
'MMMU_DEV_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv',
|
||||
'MMMU_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv',
|
||||
}
|
||||
|
||||
DATASET_MD5 = {
|
||||
'MMMU_DEV_VAL': '585e8ad75e73f75dcad265dfd0417d64',
|
||||
'MMMU_TEST': 'c19875d11a2d348d07e5eb4bdf33166d',
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def split_MMMU(msgs):
|
||||
text, images = None, []
|
||||
for s in msgs:
|
||||
if s['type'] == 'image':
|
||||
images.append(s['value'])
|
||||
elif s['type'] == 'text':
|
||||
assert text is None
|
||||
text = s['value']
|
||||
text_segs = text.split('<image ')
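# MMMU questions reference images with inline placeholders such as "<image 1>".
# Splitting on '<image ' leaves every later segment starting with the 1-based image
# index followed by '>', which is what the parsing below relies on to interleave
# text and image messages.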
|
||||
if len(text_segs) == 1:
|
||||
return msgs
|
||||
|
||||
segs = [dict(type='text', value=text_segs[0])]
|
||||
for i, seg in enumerate(text_segs):
|
||||
if i == 0:
|
||||
continue
|
||||
assert istype(seg[0], int) and seg[1] == '>'
|
||||
image_idx = int(seg[0]) - 1
|
||||
segs.append(dict(type='image', value=images[image_idx]))
|
||||
segs.append(dict(type='text', value=seg[2:]))
|
||||
return segs
|
||||
|
||||
def build_prompt(self, line):
|
||||
msgs = super().build_prompt(line)
|
||||
msgs = self.split_MMMU(msgs)
|
||||
return msgs
|
||||
|
||||
|
||||
class MUIRDataset(ImageMCQDataset):
|
||||
|
||||
DATASET_URL = {
|
||||
'MUIRBench': 'http://opencompass.openxxlab.com/utils/VLMEval/MUIRBench.tsv'
|
||||
}
|
||||
|
||||
DATASET_MD5 = {
|
||||
'MUIRBench': '2e5e6fd7699761b08a7cb3ab8c0c2ec8'
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def split_MUIR(msgs):
|
||||
text, images = None, []
|
||||
|
||||
# Separate images and text from msgs
|
||||
for s in msgs:
|
||||
if s['type'] == 'image':
|
||||
images.append(s['value'])
|
||||
elif s['type'] == 'text':
|
||||
assert text is None # Ensure only one text entry is expected
|
||||
text = s['value']
|
||||
|
||||
# Split text by <image> tags
|
||||
text_segs = text.split('<image>')
|
||||
|
||||
# Initialize the segments list
|
||||
segs = []
|
||||
|
||||
# Iterate through the text segments and images
|
||||
for i, seg in enumerate(text_segs):
|
||||
# Append the image if this is not the first segment and there are still images left
|
||||
if i > 0 and i - 1 < len(images):
|
||||
segs.append(dict(type='image', value=images[i - 1]))
|
||||
# Append the text segment (if it's non-empty)
|
||||
if len(seg) > 0:
|
||||
segs.append(dict(type='text', value=seg))
|
||||
|
||||
return segs
|
||||
|
||||
def build_prompt(self, line):
|
||||
|
||||
if isinstance(line, int):
|
||||
line = self.data.iloc[line]
|
||||
|
||||
if self.meta_only:
|
||||
tgt_path = toliststr(line['image_path'])
|
||||
else:
|
||||
tgt_path = self.dump_image(line)
|
||||
|
||||
question = line['question']
|
||||
options = {
|
||||
cand: line[cand]
|
||||
for cand in string.ascii_uppercase
|
||||
if cand in line and not pd.isna(line[cand])
|
||||
}
|
||||
# options_prompt = ''
|
||||
options_prompt = '\n'.join([f'{key}. {item}' for key, item in options.items()])
|
||||
# for key, item in options.items():
|
||||
# options_prompt += f'{key}. {item}\n'
|
||||
|
||||
prompt = ''
|
||||
|
||||
prompt += f'{question}\n'
|
||||
if len(options):
|
||||
prompt += options_prompt
|
||||
prompt += "\nAnswer with the option's letter from the given choices directly."
|
||||
|
||||
msgs = []
|
||||
if isinstance(tgt_path, list):
|
||||
msgs.extend([dict(type='image', value=p) for p in tgt_path])
|
||||
else:
|
||||
msgs = [dict(type='image', value=tgt_path)]
|
||||
msgs.append(dict(type='text', value=prompt))
|
||||
|
||||
msgs = self.split_MUIR(msgs)
|
||||
return msgs
|
||||
|
||||
|
||||
class GMAIMMBenchDataset(ImageMCQDataset):
|
||||
|
||||
DATASET_URL = {
|
||||
'GMAI-MMBench_VAL': 'https://huggingface.co/datasets/VLMEval/GMAI-MMBench/resolve/main/GMAI-MMBench_VAL.tsv',
|
||||
'GMAI_mm_bench_TEST_part_1': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_1.tsv', # noqa: E501
|
||||
'GMAI_mm_bench_TEST_part_2': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_2.tsv', # noqa: E501
|
||||
'GMAI_mm_bench_TEST_part_3': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_3.tsv', # noqa: E501
|
||||
'GMAI_mm_bench_TEST_part_4': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_4.tsv', # noqa: E501
|
||||
'GMAI_mm_bench_TEST_part_5': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_5.tsv', # noqa: E501
|
||||
'GMAI_mm_bench_TEST_part_6': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_6.tsv', # noqa: E501
|
||||
'GMAI_mm_bench_TEST_part_7': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_7.tsv', # noqa: E501
|
||||
'GMAI_mm_bench_TEST_part_8': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_8.tsv', # noqa: E501
|
||||
'GMAI_mm_bench_TEST_part_9': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_9.tsv', # noqa: E501
|
||||
'GMAI_mm_bench_TEST_part_10': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_10.tsv', # noqa: E501
|
||||
'GMAI_mm_bench_TEST_part_11': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_11.tsv', # noqa: E501
|
||||
}
|
||||
|
||||
DATASET_MD5 = {
|
||||
'GMAI-MMBench_VAL': '254bd581627866f1c499d3d6b4422324',
|
||||
'GMAI_mm_bench_TEST_part_1': '900d735231230a63f4ed45665c078ef4',
|
||||
'GMAI_mm_bench_TEST_part_2': '1b27ab621386945d7e4a765ad2d22b0e',
|
||||
'GMAI_mm_bench_TEST_part_3': '44bdc2b6267dd505d529b8cad06f0fb2',
|
||||
'GMAI_mm_bench_TEST_part_4': '5a04a04fcac9f1466709f242fdb80acb',
|
||||
'GMAI_mm_bench_TEST_part_5': 'c70baf8909eda9af0ddeab275c721336',
|
||||
'GMAI_mm_bench_TEST_part_6': '825abc39596b644dead9350d0cfa3b96',
|
||||
'GMAI_mm_bench_TEST_part_7': 'defb8aed2fb77365a76b6b9abd6a2701',
|
||||
'GMAI_mm_bench_TEST_part_8': 'ff490d60b85f2bb0abb67a435b298c65',
|
||||
'GMAI_mm_bench_TEST_part_9': 'ff67c86f40da93b09139ac1d1ba5dc6b',
|
||||
'GMAI_mm_bench_TEST_part_10': '3dae94627b9ac0fe00180d4780fbf6dc',
|
||||
'GMAI_mm_bench_TEST_part_11': 'd08dc813f0eb6bbab63cae2a9d113c4b',
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return ['GMAI-MMBench_VAL', 'GMAI-MMBench_TEST']
|
||||
|
||||
def load_data(self, dataset):
|
||||
if dataset == 'GMAI-MMBench_VAL':
|
||||
data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')
|
||||
if file_size(data_path, 'GB') > 1:
|
||||
local_path = data_path.replace('.tsv', '_local.tsv')
|
||||
if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL'):
|
||||
from ..tools import LOCALIZE
|
||||
LOCALIZE(data_path, local_path)
|
||||
data_path = local_path
|
||||
return load(data_path)
|
||||
elif dataset == 'GMAI-MMBench_TEST':
    dfs = []
    for part_num in range(1, 12):
        part_name = f'GMAI_mm_bench_TEST_part_{part_num}'
        url = self.DATASET_URL[part_name]
        file_md5 = self.DATASET_MD5.get(part_name)
        tsv_path = osp.join(LMUDataRoot(), f'{part_name}.tsv')
        if not osp.exists(tsv_path) or (file_md5 and md5(tsv_path) != file_md5):
            download_file(url, filename=tsv_path)
        local_path = tsv_path.replace('.tsv', '_local.tsv')
        if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL'):
            from ..tools import LOCALIZE
            LOCALIZE(tsv_path, local_path)
            tsv_path = local_path
        # load this part
        df = load(tsv_path)
        dfs.append(df)
    # merge all parts
    data = pd.concat(dfs, ignore_index=True)
    return data
else:
    raise ValueError(f"Unknown dataset: {dataset}")
|
||||
|
||||
def report_acc_by_groups(self, df, group_column):
|
||||
res = defaultdict(list)
|
||||
|
||||
# Check for the 'split' column
|
||||
if 'split' in df:
|
||||
splits = list(set(df['split']))
|
||||
res['split'] = splits
|
||||
else:
|
||||
df['split'] = ['none'] * len(df)
|
||||
res['split'] = ['none']
|
||||
|
||||
res['Overall'] = [np.mean(df[df['split'] == sp]['hit']) for sp in res['split']]
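# 'hit' is the per-sample 0/1 correctness flag, so the mean over each split is its accuracy.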
|
||||
|
||||
if group_column not in df:
|
||||
raise ValueError(f"Column '{group_column}' not found in dataframe.") # noqa: E713
|
||||
|
||||
abilities = list(set(df[group_column]))
|
||||
abilities = ['None' if isinstance(ab, float) and pd.isna(ab) else ab for ab in abilities]
|
||||
abilities.sort()
|
||||
|
||||
for ab in abilities:
|
||||
ab_name = ab
|
||||
sub_df = df[df[group_column] == ab]
|
||||
res[ab_name] = [np.mean(sub_df[sub_df['split'] == sp]['hit']) for sp in res['split']]
|
||||
|
||||
return pd.DataFrame(res)
|
||||
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
from .utils.multiple_choice import report_acc, mcq_vanilla_eval
|
||||
nproc = judge_kwargs.pop('nproc', 4)
|
||||
|
||||
suffix = eval_file.split('.')[-1]
|
||||
model = judge_kwargs.get('model', 'exact_matching')
|
||||
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
||||
name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
|
||||
name_str = name_str_map[model] if model in name_str_map else model
|
||||
|
||||
if model == 'exact_matching':
|
||||
model = None
|
||||
elif gpt_key_set():
|
||||
model = build_judge(**judge_kwargs)
|
||||
if not model.working():
|
||||
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
||||
warnings.warn(DEBUG_MESSAGE)
|
||||
model = None
|
||||
else:
|
||||
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
||||
model = None
|
||||
|
||||
result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
|
||||
|
||||
data = load(eval_file)
|
||||
data = data.sort_values(by='index')
|
||||
data['prediction'] = [str(x) for x in data['prediction']]
|
||||
# If not choice label, then use lower case
|
||||
for k in data.keys():
|
||||
data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
|
||||
|
||||
meta = self.data
|
||||
meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
|
||||
data_map = {x: y for x, y in zip(data['index'], data['question'])}
|
||||
for k in data_map:
|
||||
assert k in meta_q_map, (
|
||||
f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
|
||||
)
|
||||
|
||||
data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
|
||||
|
||||
# load split
|
||||
dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
||||
data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
||||
|
||||
acc = report_acc(data)
|
||||
|
||||
for group_col in ['clinical vqa task', 'department', 'perceptual granularity']:
|
||||
acc_grouped = self.report_acc_by_groups(data, group_col)
|
||||
score_file_grouped = eval_file.replace(f'.{suffix}', f'_{group_col}_acc.csv')
|
||||
dump(acc_grouped, score_file_grouped)
|
||||
|
||||
return acc
|
||||
|
||||
|
||||
class MMERealWorld(ImageMCQDataset):
|
||||
|
||||
TYPE = 'MMERealWorld'
|
||||
|
||||
DATASET_MD5 = {
|
||||
'MME-RealWorld': '271c33ec814c39533c467ec6fb8a6f36',
|
||||
'MME-RealWorld-Lite': '4c17057d7d3b6c4a0d4397c3dae0881c',
|
||||
'MME-RealWorld-CN': 'daaa763d52a760a38606d5dedb3fe444',
|
||||
}
|
||||
SYS = {
|
||||
'MME-RealWorld': (
|
||||
'Select the best answer to the above multiple-choice question based on the image. '
|
||||
'Respond with only the letter (A, B, C, D, or E) of the correct option. \n'
|
||||
'The best answer is:'
|
||||
),
|
||||
'MME-RealWorld-Lite': (
|
||||
'Select the best answer to the above multiple-choice question based on the image. '
|
||||
'Respond with only the letter (A, B, C, D, or E) of the correct option. \n'
|
||||
'The best answer is:'
|
||||
),
|
||||
'MME-RealWorld-CN': (
|
||||
'根据图像选择上述多项选择题的最佳答案。只需回答正确选项的字母(A, B, C, D 或 E)。\n'
|
||||
'最佳答案为:'
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return ['MME-RealWorld', 'MME-RealWorld-CN', 'MME-RealWorld-Lite',]
|
||||
|
||||
def load_data(
|
||||
self, dataset="MME-RealWorld", repo_id="yifanzhang114/MME-RealWorld-Base64"
|
||||
):
|
||||
|
||||
def check_integrity(pth):
|
||||
data_file = osp.join(pth, f"{dataset}.tsv")
|
||||
|
||||
if not os.path.exists(data_file):
|
||||
return False
|
||||
|
||||
if md5(data_file) != self.DATASET_MD5[dataset]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def generate_tsv(pth):
|
||||
tsv_file = os.path.join(pth, f"{dataset}.tsv")
|
||||
|
||||
if os.path.exists(tsv_file):
|
||||
print(f"{tsv_file} already exists.")
|
||||
return
|
||||
|
||||
json_dir = os.path.join(pth, dataset)
|
||||
json_files = [f for f in os.listdir(json_dir) if f.endswith(".json")]
|
||||
|
||||
data_list = []
|
||||
for json_file in json_files:
|
||||
with open(os.path.join(json_dir, json_file), "r") as f:
|
||||
data = json.load(f)
|
||||
for item in tqdm(data):
|
||||
choice_prompt = (
|
||||
"The choices are listed below:\n"
|
||||
if dataset in ["MME-RealWorld", "MME-RealWorld-Lite"]
|
||||
else "选项如下所示:\n"
|
||||
)
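# Each entry of item["multi-choice options"] appears to be formatted like "(A) xxx";
# the [4:] slices below keep only the option text, dropping the "(A) " style prefix
# (an assumption inferred from the fixed-width slicing used here).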
|
||||
data_list.append(
|
||||
{
|
||||
"index": item["index"],
|
||||
"image": item["image"],
|
||||
"question": item["question"],
|
||||
"multi-choice options": choice_prompt
|
||||
+ "\n".join(item["multi-choice options"]),
|
||||
"A": item["multi-choice options"][0][4:],
|
||||
"B": item["multi-choice options"][1][4:],
|
||||
"C": item["multi-choice options"][2][4:],
|
||||
"D": item["multi-choice options"][3][4:],
|
||||
"E": item["multi-choice options"][4][4:],
|
||||
"answer": item["answer"],
|
||||
"category": item["category"],
|
||||
"l2-category": item["l2-category"],
|
||||
}
|
||||
)
|
||||
df = pd.DataFrame(data_list)
|
||||
df.to_csv(tsv_file, sep="\t", index=False)
|
||||
print(f"TSV file saved to {tsv_file}")
|
||||
|
||||
# Check if dataset is cached and has integrity
|
||||
if dataset == "MME-RealWorld-Lite":
|
||||
url = 'https://huggingface.co/datasets/yifanzhang114/MME-RealWorld-Base64/resolve/main/mme_realworld_lite.tsv' # noqa: E501
|
||||
file_md5 = (
|
||||
self.DATASET_MD5[dataset] if dataset in self.DATASET_MD5 else None
|
||||
)
|
||||
datas = self.prepare_tsv(url, file_md5)
|
||||
choice_prompt = "The choices are listed below:\n"
|
||||
for index, item in datas.iterrows():
|
||||
options = eval(item["multi-choice options"])
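# The Lite TSV stores the options column as a stringified Python list, so eval()
# reconstructs the list here; ast.literal_eval would be a safer drop-in alternative.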
|
||||
datas.loc[index, "multi-choice options"] = choice_prompt + "\n".join(
|
||||
options
|
||||
)
|
||||
datas.loc[index, "A"] = options[0][4:]
|
||||
datas.loc[index, "B"] = options[1][4:]
|
||||
datas.loc[index, "C"] = options[2][4:]
|
||||
datas.loc[index, "D"] = options[3][4:]
|
||||
datas.loc[index, "E"] = options[4][4:]
|
||||
return datas
|
||||
|
||||
update_flag = False
|
||||
cache_path = get_cache_path(repo_id)
|
||||
if cache_path is not None and check_integrity(cache_path):
|
||||
dataset_path = cache_path
|
||||
print(f"Using cached dataset from {cache_path}")
|
||||
else:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download or find the dataset path
|
||||
dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
|
||||
generate_tsv(dataset_path)
|
||||
update_flag = True
|
||||
|
||||
data_path = os.path.join(dataset_path, f"{dataset}.tsv")
|
||||
if file_size(data_path, "GB") > 1:
|
||||
local_path = data_path.replace(".tsv", "_local.tsv")
|
||||
if (
|
||||
not osp.exists(local_path)
|
||||
or os.environ.get("FORCE_LOCAL", None)
|
||||
or update_flag
|
||||
):
|
||||
from vlmeval.tools import LOCALIZE
|
||||
|
||||
LOCALIZE(data_path, local_path)
|
||||
data_path = local_path
|
||||
return load(data_path)
|
||||
|
||||
def post_build(self, dataset):
|
||||
self.TYPE = 'MMERealWorld'
|
||||
|
||||
# Given one data record, return the built prompt (a multi-modal message), can override
|
||||
def build_prompt(self, line):
|
||||
if isinstance(line, int):
|
||||
line = self.data.iloc[line]
|
||||
|
||||
if self.meta_only:
|
||||
tgt_path = toliststr(line['image_path'])
|
||||
else:
|
||||
tgt_path = self.dump_image(line)
|
||||
|
||||
question = line['question']
|
||||
|
||||
choice_prompt = line['multi-choice options'] + '\n'
|
||||
question += ' ' + choice_prompt + self.SYS[self.dataset_name]
|
||||
|
||||
msgs = []
|
||||
if isinstance(tgt_path, list):
|
||||
msgs.extend([dict(type='image', value=p) for p in tgt_path])
|
||||
else:
|
||||
msgs = [dict(type='image', value=tgt_path)]
|
||||
msgs.append(dict(type='text', value=question))
|
||||
return msgs
|
||||
|
||||
# It returns a dictionary
|
||||
@classmethod
|
||||
def evaluate(cls, eval_file, **judge_kwargs):
|
||||
from .utils.multiple_choice import extract_characters_regex, get_dimension_rating
|
||||
assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
|
||||
FAIL_MSG = 'Failed to obtain answer via API.'
|
||||
tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
|
||||
tgt_file = eval_file.replace('.xlsx', '_rating.json')
|
||||
score_file = eval_file.replace('.xlsx', '_score.xlsx')
|
||||
|
||||
if not osp.exists(score_file):
|
||||
|
||||
res = {} if not osp.exists(tmp_file) else load(tmp_file)
|
||||
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
|
||||
|
||||
data = load(eval_file)
|
||||
cnt_rejected = 0
|
||||
data_un = data[~pd.isna(data['prediction'])]
|
||||
|
||||
for idx in data['index']:
|
||||
ans = data.loc[data['index'] == idx, 'answer'].values[0]
|
||||
pred = data.loc[data['index'] == idx, 'prediction'].values[0]
|
||||
|
||||
extract_pred = extract_characters_regex(pred)
|
||||
if extract_pred == '':
|
||||
cnt_rejected += 1
|
||||
data.loc[data['index'] == idx, 'score'] = 0
|
||||
else:
|
||||
data.loc[data['index'] == idx, 'score'] = int(extract_pred == ans)
|
||||
|
||||
print(
|
||||
f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
|
||||
f'failed to obtain the score for another {cnt_rejected} questions. '
|
||||
f'Those questions will be counted as 0 score in ALL rating.'
|
||||
)
|
||||
|
||||
dump(data, score_file)
|
||||
|
||||
rating = get_dimension_rating(score_file)
|
||||
dump(rating, tgt_file)
|
||||
return rating
|
||||
|
||||
|
||||
class HRBenchDataset(ImageMCQDataset):
|
||||
|
||||
DATASET_URL = {
|
||||
'HRBench4K': 'https://huggingface.co/datasets/DreamMr/HR-Bench/resolve/main/hr_bench_4k.tsv',
|
||||
'HRBench8K': 'https://huggingface.co/datasets/DreamMr/HR-Bench/resolve/main/hr_bench_8k.tsv',
|
||||
}
|
||||
|
||||
DATASET_MD5 = {
|
||||
'HRBench4K': 'f6b041b03d49543494b8a56d2e35be65',
|
||||
'HRBench8K': '274c9c7f89329b804a4723178a00219c',
|
||||
}
|
||||
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
assert os.path.exists(eval_file), '{} does not exist!'.format(eval_file)
|
||||
from .utils.multiple_choice import mcq_vanilla_eval
|
||||
from .utils.hrbench import report_acc_hrbench
|
||||
nproc = judge_kwargs.pop('nproc', 4)
|
||||
|
||||
suffix = eval_file.split('.')[-1]
|
||||
model = judge_kwargs.get('model', 'exact_matching')
|
||||
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
||||
name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
|
||||
name_str = name_str_map[model] if model in name_str_map else model
|
||||
|
||||
if model == 'exact_matching':
|
||||
model = None
|
||||
elif gpt_key_set():
|
||||
model = build_judge(**judge_kwargs)
|
||||
if not model.working():
|
||||
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
||||
warnings.warn(DEBUG_MESSAGE)
|
||||
model = None
|
||||
else:
|
||||
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
||||
model = None
|
||||
|
||||
result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
|
||||
|
||||
data = load(eval_file)
|
||||
data = data.sort_values(by='index')
|
||||
data['prediction'] = [str(x) for x in data['prediction']]
|
||||
# If not choice label, then use lower case
|
||||
for k in data.keys():
|
||||
data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
|
||||
|
||||
meta = self.data
|
||||
meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
|
||||
data_map = {x: y for x, y in zip(data['index'], data['question'])}
|
||||
for k in data_map:
|
||||
assert k in meta_q_map, (
|
||||
f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
|
||||
)
|
||||
|
||||
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
||||
|
||||
if osp.exists(score_file):
|
||||
acc = load(score_file)
|
||||
return acc
|
||||
data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
|
||||
dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
||||
data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
||||
|
||||
acc = report_acc_hrbench(data)
|
||||
|
||||
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
||||
dump(acc, score_file)
|
||||
|
||||
return acc
|
||||
|
||||
|
||||
class CustomMCQDataset(ImageMCQDataset):
|
||||
|
||||
def load_data(self, dataset):
|
||||
data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')
|
||||
|
||||
if file_size(data_path, 'GB') > 1:
|
||||
local_path = data_path.replace('.tsv', '_local.tsv')
|
||||
if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None):
|
||||
from ..tools import LOCALIZE
|
||||
LOCALIZE(data_path, local_path)
|
||||
data_path = local_path
|
||||
return load(data_path)
|
||||
|
||||
|
||||
class NaturalBenchDataset(ImageMCQDataset):
|
||||
|
||||
DATASET_URL = {
|
||||
'NaturalBenchDataset': (
|
||||
'https://huggingface.co/datasets/BaiqiL/'
|
||||
'NaturalBench/resolve/main/NaturalBenchDataset.tsv'
|
||||
),
|
||||
}
|
||||
DATASET_MD5 = {
|
||||
'NaturalBenchDataset':'dbe25b044bc35696426381e9ba4fe930',
|
||||
}
|
||||
|
||||
def build_prompt(self, line):
|
||||
SUFFIX_FOR_VQA = {
|
||||
"yes_no": "Please answer Yes or No.",
|
||||
"multiple_choice": "Please output the letter corresponding to the correct option."
|
||||
}
|
||||
if isinstance(line, int):
|
||||
line = self.data.iloc[line]
|
||||
|
||||
if self.meta_only:
|
||||
tgt_path = toliststr(line['image_path'])
|
||||
else:
|
||||
tgt_path = self.dump_image(line)
|
||||
|
||||
question = line['question']
|
||||
prompt = f'{question} {SUFFIX_FOR_VQA[line["type"]]}'
|
||||
msgs = []
|
||||
if isinstance(tgt_path, list):
|
||||
msgs.extend([dict(type='image', value=p) for p in tgt_path])
|
||||
else:
|
||||
msgs = [dict(type='image', value=tgt_path)]
|
||||
msgs.append(dict(type='text', value=prompt))
|
||||
|
||||
return msgs
|
||||
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
from .utils.naturalbench import extract_answer, get_scores
|
||||
|
||||
data = load(eval_file)
|
||||
data = data.sort_values(by='index')
|
||||
predictions = [str(x) for x in data['prediction']]
|
||||
answers = [str(x) for x in data['answer']]
|
||||
indexs = [str(x) for x in data['index']]
|
||||
meta = self.data
|
||||
types = [str(x) for x in meta['type']]
|
||||
results = {}
|
||||
assert len(predictions) == len(answers) == len(indexs) == len(types) == (1900 * 4)
|
||||
number_answered_samples = len(predictions) // 4
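# NaturalBench groups every four consecutive rows into one sample: two questions
# (q0, q1) asked over two images (i0, i1), which is why the total row count is 1900 * 4.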
|
||||
for i in range(number_answered_samples):
|
||||
results[i] = {
|
||||
"q0_i0": extract_answer(predictions[i * 4], types[i * 4]),
|
||||
"q0_i1": extract_answer(predictions[i * 4 + 1], types[i * 4 + 1]),
|
||||
"q1_i0": extract_answer(predictions[i * 4 + 2], types[i * 4 + 2]),
|
||||
"q1_i1": extract_answer(predictions[i * 4 + 3], types[i * 4 + 3])
|
||||
}
|
||||
|
||||
scores = get_scores(results)
|
||||
print(scores)
|
||||
score_file = 'NaturalBench_acc.csv'
|
||||
df = pd.DataFrame(list(scores.items()), columns=['Metric', 'Score'])
|
||||
dump(df, score_file)
|
||||
|
||||
return scores
|
||||
128
eval_mm/vlmevalkit/vlmeval/dataset/image_mt.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from .image_base import ImageBaseDataset
|
||||
from .utils.judge_util import build_judge
|
||||
from ..smp import *
|
||||
from ..utils import track_progress_rich
|
||||
|
||||
|
||||
class ImageMTDataset(ImageBaseDataset):
|
||||
|
||||
TYPE = 'MT'
|
||||
|
||||
def build_prompt(self, line):
|
||||
if isinstance(line, int):
|
||||
line = self.data.iloc[line]
|
||||
|
||||
if self.meta_only:
|
||||
tgt_path = toliststr(line['image_path'])
|
||||
else:
|
||||
tgt_path = self.dump_image(line)
|
||||
|
||||
questions = toliststr(line['question'])
|
||||
if 'answer' in line:
|
||||
answers = toliststr(line['answer'])
|
||||
else:
|
||||
answers = [''] * len(questions)
|
||||
assert len(questions) == len(answers)
|
||||
|
||||
dlgs, pics_number = [], 0
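# Build a multi-turn dialog: each '<ImageHere>' tag in a question consumes the next
# unused image from tgt_path, so text and images stay interleaved in their original order.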
|
||||
for i in range(len(questions)):
|
||||
q, a = questions[i], answers[i]
|
||||
if '<ImageHere>' in q:
|
||||
content = []
|
||||
tag_number = q.count('<ImageHere>')
|
||||
images = tgt_path[pics_number: pics_number + tag_number]
|
||||
pics_number += tag_number
|
||||
q_split = q.split('<ImageHere>')
|
||||
for i in range(tag_number):
|
||||
qsp, im = q_split[i], images[i]
|
||||
if qsp != '':
|
||||
content.append(dict(type='text', value=qsp))
|
||||
content.append(dict(type='image', value=im))
|
||||
if q_split[-1] != '':
|
||||
content.append(dict(type='text', value=q_split[-1]))
|
||||
else:
|
||||
content = [dict(type='text', value=q)]
|
||||
dlgs.append(dict(role='user', content=content))
|
||||
assert '<ImageHere>' not in a, 'We currently do not support images in the answer. '
|
||||
content = [dict(type='text', value=a)]
|
||||
dlgs.append(dict(role='assistant', content=content))
|
||||
return dlgs
|
||||
|
||||
|
||||
class MMDUDataset(ImageMTDataset):
|
||||
|
||||
DATASET_URL = {'MMDU': 'https://opencompass.openxlab.space/utils/VLMEval/MMDU.tsv'}
|
||||
DATASET_MD5 = {'MMDU': '848b635a88a078f49aebcc6e39792061'}
|
||||
DIMS = [
|
||||
'Creativity', 'Richness', 'Visual Perception', 'Logical Coherence',
|
||||
'Answer Accuracy', 'Image Relationship Understanding', 'Overall Score'
|
||||
]
|
||||
|
||||
def calculate_metric(self, ans):
|
||||
all = defaultdict(lambda: 0)
|
||||
tot = defaultdict(lambda: 0)
|
||||
valid = defaultdict(lambda: 0)
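# 'tot' counts every judged turn per dimension, while 'valid' only counts turns whose
# score could be parsed; the two result rows below differ only in that denominator.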
|
||||
for k in ans:
|
||||
res = ans[k]['res']
|
||||
assert isinstance(res, pd.DataFrame)
|
||||
lt = len(res)
|
||||
for i in range(lt):
|
||||
line = res.iloc[i]
|
||||
for k in self.DIMS:
|
||||
tot[k] += 1
|
||||
if k in line and line[k] is not None:
|
||||
try:
|
||||
score = int(line[k])
|
||||
score = np.clip(score, 0, 10)
|
||||
all[k] += score
|
||||
valid[k] += 1
|
||||
except Exception as e:
|
||||
print(f'Failed to parse the score: {str(e)}')
|
||||
sp1 = {'set': 'all'}
|
||||
sp1.update({k: all[k] / tot[k] * 10 for k in self.DIMS})
|
||||
sp2 = {'set': 'valid'}
|
||||
sp2.update({k: all[k] / valid[k] * 10 for k in self.DIMS})
|
||||
|
||||
return pd.DataFrame([sp1, sp2])
|
||||
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
suffix = eval_file.split('.')[-1]
|
||||
model = judge_kwargs['model']
|
||||
|
||||
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
||||
score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.csv')
|
||||
nproc = judge_kwargs.pop('nproc', 4)
|
||||
|
||||
data = load(eval_file)
|
||||
model = judge_kwargs.pop('model', 'gpt-4o')
|
||||
judge_model = build_judge(model=model, **judge_kwargs)
|
||||
|
||||
lt = len(data)
|
||||
lines = [data.iloc[i] for i in range(lt)]
|
||||
tups = [(judge_model, line) for line in lines]
|
||||
indices = [line['index'] for line in lines]
|
||||
|
||||
ans = {}
|
||||
if osp.exists(tmp_file):
|
||||
ans = load(tmp_file)
|
||||
|
||||
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
||||
indices = [i for i in indices if i not in ans]
|
||||
|
||||
from .utils.mmdu import mmdu_score
|
||||
|
||||
if len(indices):
|
||||
new_results = track_progress_rich(
|
||||
mmdu_score,
|
||||
tups,
|
||||
nproc=nproc,
|
||||
chunksize=nproc,
|
||||
keys=indices,
|
||||
save=tmp_file,)
|
||||
ans = load(tmp_file)
|
||||
for k, v in zip(indices, new_results):
|
||||
assert k in ans
|
||||
|
||||
metric = self.calculate_metric(ans)
|
||||
dump(metric, score_file)
|
||||
return metric
|
||||
1475
eval_mm/vlmevalkit/vlmeval/dataset/image_vqa.py
Normal file
95
eval_mm/vlmevalkit/vlmeval/dataset/image_yorn.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from ..smp import *
|
||||
from ..utils import *
|
||||
from .image_base import ImageBaseDataset
|
||||
from .utils import build_judge, DEBUG_MESSAGE
|
||||
|
||||
|
||||
class ImageYORNDataset(ImageBaseDataset):
|
||||
|
||||
TYPE = 'Y/N'
|
||||
|
||||
DATASET_URL = {
|
||||
'MME': 'https://opencompass.openxlab.space/utils/VLMEval/MME.tsv',
|
||||
'HallusionBench': 'https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv',
|
||||
'POPE': 'https://opencompass.openxlab.space/utils/VLMEval/POPE.tsv',
|
||||
'AMBER': 'https://huggingface.co/datasets/yifanzhang114/AMBER_base64/resolve/main/AMBER.tsv',
|
||||
}
|
||||
|
||||
DATASET_MD5 = {
|
||||
'MME': 'b36b43c3f09801f5d368627fb92187c3',
|
||||
'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c',
|
||||
'POPE': 'c12f5acb142f2ef1f85a26ba2fbe41d5',
|
||||
'AMBER': '970d94c0410916166e0a76ba75da7934',
|
||||
}
|
||||
|
||||
# It returns a dataframe
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
from .utils.yorn import YOrN_Extraction, YOrN_auxeval
|
||||
from .utils.yorn import default_rating, MME_rating, Hallusion_rating, POPE_rating, AMBER_rating
|
||||
|
||||
dataset = self.dataset_name
|
||||
data = load(eval_file)
|
||||
data['prediction'] = [str(x) for x in data['prediction']]
|
||||
storage = eval_file.replace('.xlsx', '_auxmatch.xlsx')
|
||||
tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
|
||||
nproc = judge_kwargs.pop('nproc', 4)
|
||||
|
||||
if not osp.exists(storage):
|
||||
ans_map = {k: YOrN_Extraction(v) for k, v in zip(data['index'], data['prediction'])}
|
||||
if osp.exists(tmp_file):
|
||||
tmp = load(tmp_file)
|
||||
for k in tmp:
|
||||
if ans_map[k] == 'Unknown' and tmp[k] != 'Unknown':
|
||||
ans_map[k] = tmp[k]
|
||||
|
||||
data['extracted'] = [ans_map[x] for x in data['index']]
|
||||
unknown = data[data['extracted'] == 'Unknown']
|
||||
|
||||
model = judge_kwargs.get('model', 'exact_matching')
|
||||
if model == 'exact_matching':
|
||||
model = None
|
||||
elif gpt_key_set():
|
||||
model = build_judge(**judge_kwargs)
|
||||
if not model.working():
|
||||
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
||||
warnings.warn(DEBUG_MESSAGE)
|
||||
model = None
|
||||
else:
|
||||
model = None
|
||||
warnings.warn('OPENAI_API_KEY is not working properly, will use exact matching for evaluation')
|
||||
|
||||
if model is not None:
|
||||
lt = len(unknown)
|
||||
lines = [unknown.iloc[i] for i in range(lt)]
|
||||
tups = [(model, line) for line in lines]
|
||||
indices = list(unknown['index'])
|
||||
if len(tups):
|
||||
res = track_progress_rich(
|
||||
YOrN_auxeval, tups, nproc=nproc, chunksize=nproc, keys=indices, save=tmp_file)
|
||||
for k, v in zip(indices, res):
|
||||
ans_map[k] = v
|
||||
|
||||
data['extracted'] = [ans_map[x] for x in data['index']]
|
||||
dump(data, storage)
|
||||
|
||||
data = load(storage)
|
||||
if listinstr(['AMBER'], dataset):
|
||||
data['score'] = (data['answer'].str.lower() == data['extracted'].str.lower())
|
||||
else:
|
||||
data['score'] = (data['answer'] == data['extracted'])
|
||||
dump(data, storage)
|
||||
|
||||
if dataset is not None and listinstr(['MME'], dataset):
|
||||
score = MME_rating(storage)
|
||||
elif dataset is not None and listinstr(['Hallusion'], dataset):
|
||||
score = Hallusion_rating(storage)
|
||||
elif dataset is not None and listinstr(['POPE'], dataset):
|
||||
score = POPE_rating(storage)
|
||||
elif dataset is not None and listinstr(['AMBER'], dataset):
|
||||
score = AMBER_rating(storage)
|
||||
else:
|
||||
score = default_rating(storage)
|
||||
|
||||
score_tgt = eval_file.replace('.xlsx', '_score.csv')
|
||||
dump(score, score_tgt)
|
||||
return score
|
||||
328
eval_mm/vlmevalkit/vlmeval/dataset/longvideobench.py
Normal file
@@ -0,0 +1,328 @@
|
||||
from huggingface_hub import snapshot_download
|
||||
from ..smp import *
|
||||
from .video_base import VideoBaseDataset
|
||||
from .utils import build_judge, DEBUG_MESSAGE
|
||||
from glob import glob
|
||||
|
||||
FAIL_MSG = 'Failed to obtain answer via API.'
|
||||
|
||||
|
||||
def timestamp_to_seconds(timestamp):
|
||||
# Split the timestamp into hours, minutes, and seconds
|
||||
h, m, s = timestamp.split(":")
|
||||
# Convert hours, minutes, and total seconds (including fractions) to float and compute total seconds
|
||||
total_seconds = int(h) * 3600 + int(m) * 60 + float(s)
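# e.g. timestamp_to_seconds("01:02:03.5") -> 1 * 3600 + 2 * 60 + 3.5 = 3723.5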
|
||||
return total_seconds
|
||||
|
||||
|
||||
def uniformly_subsample(lst, K):
|
||||
n = len(lst)
|
||||
if K >= n:
|
||||
return lst
|
||||
step = n / K
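# e.g. for n = 10 and K = 4: step = 2.5 and the picked indices are 0, 2, 5, 7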
|
||||
return [lst[int(i * step)] for i in range(K)]
|
||||
|
||||
|
||||
def insert_subtitles_into_frames(
|
||||
frames,
|
||||
frame_timestamps,
|
||||
subtitles,
|
||||
starting_timestamp_for_subtitles,
|
||||
duration,
|
||||
):
|
||||
interleaved_list = []
|
||||
cur_i = 0
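# Two-pointer merge: walk frames and subtitles in time order, emitting frames whose
# timestamp falls at or before the current subtitle midpoint, then the subtitle text.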
|
||||
|
||||
for subtitle in subtitles:
|
||||
if "timestamp" in subtitle:
|
||||
start, end = subtitle["timestamp"]
|
||||
|
||||
if not isinstance(end, float):
|
||||
end = duration
|
||||
|
||||
start -= starting_timestamp_for_subtitles
|
||||
end -= starting_timestamp_for_subtitles
|
||||
|
||||
subtitle_timestamp = (start + end) / 2
|
||||
subtitle_text = subtitle["text"]
|
||||
else:
|
||||
start, end = subtitle["start"], subtitle["end"]
|
||||
start = timestamp_to_seconds(start)
|
||||
end = timestamp_to_seconds(end)
|
||||
start -= starting_timestamp_for_subtitles
|
||||
end -= starting_timestamp_for_subtitles
|
||||
|
||||
subtitle_timestamp = (start + end) / 2
|
||||
subtitle_text = subtitle["line"]
|
||||
|
||||
for i, (frame, frame_timestamp) in enumerate(
|
||||
zip(frames[cur_i:], frame_timestamps[cur_i:])
|
||||
):
|
||||
if frame_timestamp <= subtitle_timestamp:
|
||||
# print("frame:", frame_timestamp)
|
||||
interleaved_list.append({"type": "image", "value": frame})
|
||||
cur_i += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if end - start < 1:
|
||||
end = subtitle_timestamp + 0.5
|
||||
start = subtitle_timestamp - 0.5
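# Only keep the subtitle if at least one sampled frame falls inside its
# (possibly widened) time window; otherwise it is dropped.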
|
||||
|
||||
covering_frames = False
|
||||
for frame, frame_timestamp in zip(frames, frame_timestamps):
|
||||
if frame_timestamp < end and frame_timestamp > start:
|
||||
covering_frames = True
|
||||
break
|
||||
|
||||
if covering_frames:
|
||||
interleaved_list.append({"type": "text", "value": subtitle_text + "\n"})
|
||||
else:
|
||||
pass
|
||||
|
||||
for i, (frame, frame_timestamp) in enumerate(
|
||||
zip(frames[cur_i:], frame_timestamps[cur_i:])
|
||||
):
|
||||
interleaved_list.append({"type": "image", "value": frame})
|
||||
return interleaved_list
|
||||
|
||||
|
||||
class LongVideoBench(VideoBaseDataset):
|
||||
|
||||
MD5 = '82905eae3a5ae7383c5a8ee9655e1ab9'
|
||||
SYS = ''
|
||||
|
||||
TYPE = 'Video-MCQ'
|
||||
|
||||
def __init__(self, dataset='LongVideoBench', use_subtitle=False, nframe=0, fps=-1):
|
||||
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
|
||||
self.use_subtitle = use_subtitle
|
||||
self.dataset_name = dataset
|
||||
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return ['LongVideoBench']
|
||||
|
||||
def prepare_dataset(self, dataset_name='LongVideoBench', repo_id='longvideobench/LongVideoBench'):
|
||||
|
||||
def check_integrity(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
|
||||
if not osp.exists(data_file):
|
||||
return False
|
||||
|
||||
if md5(data_file) != self.MD5:
|
||||
print("md5 mismatch", md5(data_file), self.MD5)
|
||||
return False
|
||||
data = load(data_file)
|
||||
for video_pth in data['video_path']:
|
||||
if not osp.exists(osp.join(pth, video_pth)):
|
||||
print(video_pth, "is not found")
|
||||
return False
|
||||
return True
|
||||
|
||||
if modelscope_flag_set():
|
||||
repo_id = "AI-ModelScope/LongVideoBench"
|
||||
|
||||
cache_path = get_cache_path(repo_id)
|
||||
if cache_path is not None and check_integrity(cache_path):
|
||||
dataset_path = cache_path
|
||||
else:
|
||||
def generate_tsv(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
if osp.exists(data_file) and md5(data_file) == self.MD5:
|
||||
return
|
||||
|
||||
data_file = pd.read_json(osp.join(pth, 'lvb_val.json'))
|
||||
data_file = data_file.assign(index=range(len(data_file)))
|
||||
data_file['video'] = data_file['video_id']
|
||||
data_file['video_path'] = data_file['video_path'].apply(lambda x: f'./videos/{x}')
|
||||
|
||||
data_file.to_csv(osp.join(pth, f'{dataset_name}.tsv'), sep='\t', index=False)
|
||||
|
||||
if modelscope_flag_set():
|
||||
from modelscope import dataset_snapshot_download
|
||||
dataset_snapshot_download(dataset_id=repo_id)
|
||||
else:
|
||||
snapshot_download(repo_id=repo_id, repo_type='dataset')
|
||||
print("All videos are downloaded for LongVideoBench")
|
||||
|
||||
if not glob(osp.join(cache_path, "videos")):
|
||||
tar_files = glob(osp.join(cache_path, "**/*.tar*"), recursive=True)
|
||||
|
||||
def untar_video_data(tar_file, cache_dir):
|
||||
import tarfile
|
||||
with tarfile.open(tar_file, "r") as tar_ref:
|
||||
tar_ref.extractall(cache_dir)
|
||||
print(f"Extracted all files from {tar_file} to {cache_dir}")
|
||||
|
||||
def concat_tar_parts(tar_parts, output_tar):
|
||||
with open(output_tar, "wb") as out_tar:
|
||||
from tqdm import tqdm
|
||||
for part in tqdm(sorted(tar_parts)):
|
||||
with open(part, "rb") as part_file:
|
||||
out_tar.write(part_file.read())
|
||||
print(f"Concatenated parts {tar_parts} into {output_tar}")
|
||||
|
||||
tar_parts_dict = {}
|
||||
|
||||
# Group tar parts together
|
||||
for tar_file in tar_files:
|
||||
base_name = tar_file.split(".tar")[0]
|
||||
if base_name not in tar_parts_dict:
|
||||
tar_parts_dict[base_name] = []
|
||||
tar_parts_dict[base_name].append(tar_file)
|
||||
|
||||
# Concatenate and untar split parts
|
||||
for base_name, parts in tar_parts_dict.items():
|
||||
print(f"Extracting following tar files: {parts}")
|
||||
output_tar = base_name + ".tar"
|
||||
if not osp.exists(output_tar):
|
||||
print('Start concatenating tar files')
|
||||
|
||||
concat_tar_parts(parts, output_tar)
|
||||
print('Finish concatenating tar files')
|
||||
|
||||
if not osp.exists(osp.join(cache_path, osp.basename(base_name))):
|
||||
untar_video_data(output_tar, cache_path)
|
||||
|
||||
print('All videos are extracted for LongVideoBench')
|
||||
|
||||
dataset_path = cache_path
|
||||
generate_tsv(dataset_path)
|
||||
|
||||
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
||||
|
||||
return dict(data_file=data_file, root=dataset_path)
|
||||
|
||||
def save_video_frames(self, video_path, video_llm=False):
|
||||
|
||||
vid_path = osp.join(self.data_root, video_path)
|
||||
vid = decord.VideoReader(vid_path)
|
||||
video_info = {
|
||||
'fps': vid.get_avg_fps(),
|
||||
'n_frames': len(vid),
|
||||
}
|
||||
if self.nframe > 0 and self.fps < 0:
|
||||
step_size = len(vid) / (self.nframe + 1)
|
||||
indices = [int(i * step_size) for i in range(1, self.nframe + 1)]
|
||||
frame_paths = self.frame_paths(video_path[:-4])
|
||||
elif self.fps > 0:
|
||||
# not constrained by num_frames, get frames by fps
|
||||
total_duration = video_info['n_frames'] / video_info['fps']
|
||||
required_frames = int(total_duration * self.fps)
|
||||
step_size = video_info['fps'] / self.fps
|
||||
indices = [int(i * step_size) for i in range(required_frames)]
|
||||
frame_paths = self.frame_paths_fps(video_path[:-4], len(indices))
|
||||
|
||||
flag = np.all([osp.exists(p) for p in frame_paths])
|
||||
|
||||
if not flag:
|
||||
images = [vid[i].asnumpy() for i in indices]
|
||||
images = [Image.fromarray(arr) for arr in images]
|
||||
for im, pth in zip(images, frame_paths):
|
||||
if not osp.exists(pth) and not video_llm:
|
||||
im.save(pth)
|
||||
|
||||
return frame_paths, indices, video_info
|
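The two sampling branches above (fixed `nframe` vs. target `fps`) pick frame indices differently. A small numeric illustration with made-up video statistics (300 frames recorded at 30 fps):

```python
# Made-up video stats: 300 frames at 30 fps (10 s of footage).
n_frames, native_fps = 300, 30.0

# nframe branch (e.g. self.nframe == 8): 8 indices strictly inside the clip.
step = n_frames / (8 + 1)
print([int(i * step) for i in range(1, 8 + 1)])
# [33, 66, 100, 133, 166, 200, 233, 266]

# fps branch (e.g. self.fps == 1): roughly one frame per second of footage.
target_fps = 1
required_frames = int(n_frames / native_fps * target_fps)   # 10
step = native_fps / target_fps                               # 30.0
print([int(i * step) for i in range(required_frames)])
# [0, 30, 60, 90, 120, 150, 180, 210, 240, 270]
```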
||||
|
||||
# def save_video_into_images(self, line, num_frames=8):
|
||||
# frame_paths, indices, video_info = self.save_video_frames(line['video_path'], num_frames)
|
||||
# return frame_paths
|
||||
|
||||
def build_prompt(self, line, video_llm):
|
||||
if isinstance(line, int):
|
||||
assert line < len(self)
|
||||
line = self.data.iloc[line]
|
||||
|
||||
frames, indices, video_info = self.save_video_frames(line['video_path'], video_llm)
|
||||
fps = video_info["fps"]
|
||||
|
||||
message = [dict(type='text', value=self.SYS)]
|
||||
if video_llm:
|
||||
message.append(dict(type='video', value=osp.join(self.data_root, line['video_path'])))
|
||||
else:
|
||||
if not self.use_subtitle:
|
||||
with open(osp.join(self.data_root, "subtitles", line["subtitle_path"])) as f:
|
||||
subtitles = json.load(f)
|
||||
|
||||
frame_message = insert_subtitles_into_frames(
|
||||
frames,
|
||||
[ind_ / fps for ind_ in indices],
|
||||
subtitles,
|
||||
line["starting_timestamp_for_subtitles"],
|
||||
line["duration"]
|
||||
)
|
||||
|
||||
message += frame_message
|
||||
else:
|
||||
for im in frames:
|
||||
message.append(dict(type='image', value=im))
|
||||
|
||||
line['question'] += '\n' + '\n'.join(
|
||||
["{}. {}".format(chr(ord("A") + i), cand) for i, cand in enumerate(eval(line['candidates']))]
|
||||
)
|
||||
prompt = line["question"] + "\nAnswer with the option's letter from the given choices directly."
|
||||
message.append(dict(type='text', value=prompt))
|
||||
return message
|
||||
|
||||
# It returns a dictionary
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
from .utils.longvideobench import get_dimension_rating, extract_characters_regex, extract_option
|
||||
|
||||
assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
|
||||
|
||||
tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
|
||||
tgt_file = eval_file.replace('.xlsx', '_rating.json')
|
||||
score_file = eval_file.replace('.xlsx', '_score.xlsx')
|
||||
|
||||
if not osp.exists(score_file):
|
||||
model = judge_kwargs.get('model', 'exact_matching')
|
||||
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
||||
|
||||
if model == 'exact_matching':
|
||||
model = None
|
||||
elif gpt_key_set():
|
||||
model = build_judge(**judge_kwargs)
|
||||
if not model.working():
|
||||
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
||||
warnings.warn(DEBUG_MESSAGE)
|
||||
model = None
|
||||
else:
|
||||
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
||||
model = None
|
||||
res = {} if not osp.exists(tmp_file) else load(tmp_file)
|
||||
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
|
||||
|
||||
data = load(eval_file)
|
||||
data_un = data[~pd.isna(data['prediction'])]
|
||||
|
||||
for idx in data['index']:
|
||||
ans = data.loc[data['index'] == idx, 'correct_choice'].values[0]
|
||||
ans = chr(ord("A") + ans)
|
||||
pred = str(data.loc[data['index'] == idx, 'prediction'].values[0])
|
||||
|
||||
if extract_characters_regex(pred) == '':
|
||||
extract_pred = extract_option(
|
||||
model,
|
||||
data.loc[data['index'] == idx].to_dict(orient='records')[0],
|
||||
'LongVideoBench'
|
||||
)
|
||||
data.loc[idx, 'score'] = int(extract_pred == ans)
|
||||
else:
|
||||
data.loc[idx, 'score'] = int(extract_characters_regex(pred) == ans)
|
||||
|
||||
rejected = [x for x in data['score'] if x == -1]
|
||||
|
||||
print(
|
||||
f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
|
||||
f'failed to obtain the score for another {len(rejected)} questions. '
|
||||
f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
|
||||
)
|
||||
|
||||
dump(data, score_file)
|
||||
|
||||
rating = get_dimension_rating(score_file)
|
||||
dump(rating, tgt_file)
|
||||
return rating
|
||||
167
eval_mm/vlmevalkit/vlmeval/dataset/miabench.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from .image_base import ImageBaseDataset
|
||||
from ..smp import *
|
||||
from .utils import build_judge, DEBUG_MESSAGE
|
||||
from ..utils import track_progress_rich
|
||||
|
||||
|
||||
def generate_prompt(d):
|
||||
question = d['question']
|
||||
weights = eval(d['component_weight'])
|
||||
components = eval(d['components'])
|
||||
num_of_component = int(d['num_of_component'])
|
||||
response = d['prediction']
|
||||
|
||||
if num_of_component == 1:
|
||||
components = f"The first component is: '{components[0]}'. "
|
||||
score = f"The first component is worth: {weights[0]} scores. "
|
||||
elif num_of_component == 2:
|
||||
components = f"The first component is: '{components[0]}', and the second component is '{components[1]}'. "
|
||||
score = f"The first and second component is each worth {weights[0]} and {weights[1]} scores. "
|
||||
elif num_of_component == 3:
|
||||
components = (
|
||||
f"The first component is: '{components[0]}', and the second component is '{components[1]}', "
|
||||
f"and the third component is '{components[2]}'. "
|
||||
)
|
||||
score = (
|
||||
"The first, second, and third component is each worth "
|
||||
f"{weights[0]}, {weights[1]}, and {weights[2]} scores."
|
||||
)
|
||||
elif num_of_component == 4:
|
||||
components = (
|
||||
f"The first component is: '{components[0]}', and the second component is '{components[1]}', "
|
||||
f"and the third component is '{components[2]}', and the fourth component is '{components[3]}'. "
|
||||
)
|
||||
score = (
|
||||
"The first, second, third, and fourth component is each worth "
|
||||
f"{weights[0]}, {weights[1]}, {weights[2]}, and {weights[3]} scores."
|
||||
)
|
||||
elif num_of_component == 5:
|
||||
components = (
|
||||
f"The first component is: '{components[0]}', and the second component is '{components[1]}', "
|
||||
f"and the third component is '{components[2]}', and the fourth component is '{components[3]}', "
|
||||
f"and the fifth component is '{components[4]}'. "
|
||||
)
|
||||
score = (
|
||||
"The first, second, third, fourth, and fifth component is each worth "
|
||||
f"{weights[0]}, {weights[1]}, {weights[2]}, {weights[3]}, and {weights[4]} scores."
|
||||
)
|
||||
|
||||
return (
|
||||
"Here is an instruction for a multimodal LLM: '"
|
||||
f"{question}"
|
||||
"'. You need to grade if the response from the model follows each component of the instruction. "
|
||||
f"{components}"
|
||||
"The response is: '"
|
||||
f"{response}"
|
||||
"'. You need to score the response and be strict. The total score ranges from 0 to 10, "
|
||||
"depending on if the response follows the instruction. "
|
||||
f"{score}"
|
||||
"List scores of each component, and the total score in one sentence in this format: "
|
||||
"score of component 1: x/2, score of component 2: y/8, total score: z/10. Then explain your reasons."
|
||||
)
|
||||
|
||||
|
||||
def process_rawscore(component_type, raw_score):
|
||||
first_sentence = raw_score.split('.')[0].split(',')
|
||||
score_dict = {}
|
||||
for i in range(len(first_sentence) - 1):
|
||||
score_ = first_sentence[i].split(':')[1][1:].split('/')
|
||||
score = int(score_[0]) / int(score_[1])
|
||||
score_dict[component_type[i]] = score
|
||||
total_score_ = first_sentence[i + 1].split(':')[1][1:].split('/')
|
||||
total_score = int(total_score_[0]) / int(total_score_[1])
|
||||
score_dict['total_score'] = total_score
|
||||
return score_dict
|
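For context, `process_rawscore` expects the judge reply to open with a sentence of the form `score of component 1: x/2, ..., total score: z/10.`, as requested by the grading prompt above. A made-up judge reply and the parsed result (the component names are illustrative only):

```python
# Made-up judge output following the requested format.
raw = ('score of component 1: 2/2, score of component 2: 6/8, total score: 8/10. '
       'The response satisfies the first component but only part of the second.')
print(process_rawscore(['instruction following', 'content coverage'], raw))
# {'instruction following': 1.0, 'content coverage': 0.75, 'total_score': 0.8}
```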
||||
|
||||
|
||||
def get_score_dict(data, score_raw):
|
||||
cat_score_dict = {}
|
||||
for i in range(len(data)):
|
||||
try:
|
||||
cmp = data['component_type'][i][2:-2]
|
||||
cmp_list = cmp.split('\', \'')
|
||||
score_dict = process_rawscore(cmp_list, score_raw[i])
|
||||
for key, val in score_dict.items():
|
||||
if key not in cat_score_dict.keys():
|
||||
cat_score_dict[key] = [val]
|
||||
else:
|
||||
cat_score_dict[key].append(val)
|
||||
except:
|
||||
pass
|
||||
cat_score_dict_average = {}
|
||||
for key, val in cat_score_dict.items():
|
||||
cat_score_dict_average[key] = sum(val) / len(val)
|
||||
return cat_score_dict_average
|
||||
|
||||
|
||||
class MIABench(ImageBaseDataset):
|
||||
TYPE = 'VQA'
|
||||
|
||||
DATASET_URL = {
|
||||
'MIA-Bench': 'https://opencompass.openxlab.space/utils/VLMEval/Mia-Bench.tsv',
|
||||
}
|
||||
DATASET_MD5 = {
|
||||
'MIA-Bench': '0b9de595f4dd40af18a69b94d89aba82',
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
judge_name = judge_kwargs.pop('model', 'gpt-4o')
|
||||
|
||||
model = build_judge(model=judge_name, **judge_kwargs)
|
||||
suffix = eval_file.split('.')[-1]
|
||||
|
||||
storage = eval_file.replace(f'.{suffix}', f'_{judge_name}.xlsx') # noqa: F841
|
||||
tmp_file = eval_file.replace(f'.{suffix}', f'_{judge_name}.pkl') # noqa: F841
|
||||
nproc = judge_kwargs.pop('nproc', 4) # noqa: F841
|
||||
|
||||
if not osp.exists(storage):
|
||||
data = load(eval_file)
|
||||
num_samples = len(data)
|
||||
lines = [data.loc[i] for i in range(num_samples)]
|
||||
prompts = [generate_prompt(line) for line in lines]
|
||||
org_data = MIABench('MIA-Bench').data
|
||||
img_map = {x: y for x, y in zip(org_data['index'], org_data['image'])}
|
||||
image_b64 = [img_map[idx] for idx in data['index']]
|
||||
indices = list(data['index'])
|
||||
mm_messages = [
|
||||
dict(message=[
|
||||
dict(type='text', value=prompt),
|
||||
dict(type='image', value=f'data:image/jpeg;base64,{b64}')
|
||||
])
|
||||
for prompt, b64 in zip(prompts, image_b64)
|
||||
]
|
||||
|
||||
res = {}
|
||||
if osp.exists(tmp_file):
|
||||
res = load(tmp_file)
|
||||
|
||||
jobs = {k: v for k, v in zip(indices, mm_messages) if k not in res}
|
||||
job_keys = list(jobs.keys())
|
||||
job_vals = [jobs[k] for k in job_keys]
|
||||
|
||||
resps = track_progress_rich(
|
||||
model.generate,
|
||||
job_vals,
|
||||
nproc=nproc,
|
||||
chunksize=nproc,
|
||||
keys=job_keys,
|
||||
save=tmp_file,
|
||||
)
|
||||
for k, resp in zip(job_keys, resps):
|
||||
res[k] = resp
|
||||
data['score_raw'] = [res[idx] for idx in indices]
|
||||
dump(data, storage)
|
||||
|
||||
goresult = load(storage)
|
||||
results = get_score_dict(goresult, goresult['score_raw'])
|
||||
result_pth = storage.replace('.xlsx', '_score.csv')
|
||||
results_pd = pd.DataFrame.from_dict(list(results.items()))
|
||||
dump(results_pd, result_pth)
|
||||
|
||||
return results
|
||||
455
eval_mm/vlmevalkit/vlmeval/dataset/mlvu.py
Normal file
@@ -0,0 +1,455 @@
|
||||
import huggingface_hub
|
||||
from huggingface_hub import snapshot_download
|
||||
from ..smp import *
|
||||
from .video_concat_dataset import ConcatVideoDataset
|
||||
from .video_base import VideoBaseDataset
|
||||
from .utils import build_judge, DEBUG_MESSAGE
|
||||
from ..utils import track_progress_rich
|
||||
import torchvision.transforms as T
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from decord import VideoReader, cpu
|
||||
import pandas as pd
|
||||
import imageio
|
||||
import cv2
|
||||
import zipfile
|
||||
import os
|
||||
import glob
|
||||
from .utils.mlvu import *
|
||||
|
||||
FAIL_MSG = 'Failed to obtain answer via API.'
|
||||
|
||||
|
||||
class MLVU(ConcatVideoDataset):
|
||||
def __init__(self, dataset='MLVU', nframe=0, fps=-1):
|
||||
self.DATASET_SETS[dataset] = ['MLVU_MCQ', 'MLVU_OpenEnded']
|
||||
self.type_data_dict = {
|
||||
'M-Avg':['plotQA', 'needle', 'ego', 'count', 'anomaly_reco', 'topic_reasoning'],
|
||||
'G-Avg':['sub_scene', 'summary']
|
||||
}
|
||||
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
|
||||
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return ['MLVU']
|
||||
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
result = super().evaluate(eval_file=eval_file, **judge_kwargs)
|
||||
suffix = eval_file.split('.')[-1]
|
||||
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
||||
for key in self.type_data_dict:
|
||||
result.loc[key] = 0.0
|
||||
for name, item in result.iterrows():
|
||||
if name in self.type_data_dict[key]:
|
||||
result.loc[key, 'success'] += item['success']
|
||||
result.loc[key, 'overall'] += item['overall']
|
||||
if key == 'G-Avg':
|
||||
result.loc[key, 'acc'] = round(
|
||||
result.loc[key, 'success'] / result.loc[key, 'overall'], 2
|
||||
)
|
||||
else:
|
||||
result.loc[key, 'acc'] = round(
|
||||
result.loc[key, 'success'] / result.loc[key, 'overall'] * 100, 1
|
||||
)
|
||||
result = result.reset_index().rename(columns={'index': 'task'})
|
||||
dump(result, score_file)
|
||||
return result
|
||||
|
||||
|
||||
class MLVU_MCQ(VideoBaseDataset):
|
||||
|
||||
MD5 = 'bb5c37e7cf8d43fc9a25c23d2b4633f5'
|
||||
BASE_SYS = 'Carefully watch this video and pay attention to every detail. '
|
||||
SYS = BASE_SYS + 'Based on your observations, select the best option that accurately addresses the question.'
|
||||
TYPE = 'Video-MCQ'
|
||||
|
||||
def __init__(self, dataset='MLVU_MCQ', nframe=0, fps=-1):
|
||||
self.type_data_list = {
|
||||
'plotQA': ('1_plotQA.json', './MLVU/video/1_plotQA', 'MCQ'),
|
||||
'needle': ('2_needle.json', './MLVU/video/2_needle', 'MCQ'),
|
||||
'ego': ('3_ego.json', './MLVU/video/3_ego', 'MCQ'),
|
||||
'count': ('4_count.json', './MLVU/video/4_count', 'MCQ'),
|
||||
'order': ('5_order.json', './MLVU/video/5_order', 'MCQ'),
|
||||
'anomaly_reco': ('6_anomaly_reco.json', './MLVU/video/6_anomaly_reco', 'MCQ'),
|
||||
'topic_reasoning': ('7_topic_reasoning.json', './MLVU/video/7_topic_reasoning', 'MCQ'),
|
||||
}
|
||||
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
|
||||
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return ['MLVU_MCQ']
|
||||
|
||||
def prepare_dataset(self, dataset_name='MLVU_MCQ', repo_id='MLVU/MVLU'):
|
||||
def check_integrity(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
|
||||
if not os.path.exists(data_file):
|
||||
return False
|
||||
|
||||
if md5(data_file) != self.MD5:
|
||||
return False
|
||||
|
||||
data = load(data_file)
|
||||
for idx, item in data.iterrows():
|
||||
if not osp.exists(osp.join(pth, item['prefix'], item['video'])):
|
||||
return False
|
||||
return True
|
||||
|
||||
if modelscope_flag_set():
|
||||
repo_id = "AI-ModelScope/MLVU"
|
||||
|
||||
cache_path = get_cache_path(repo_id)
|
||||
if cache_path is not None and check_integrity(cache_path):
|
||||
dataset_path = cache_path
|
||||
else:
|
||||
def generate_tsv(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
if os.path.exists(data_file) and md5(data_file) == self.MD5:
|
||||
return
|
||||
json_data_dir = os.path.join(dataset_path, 'MLVU', 'json')
|
||||
self.data_list = []
|
||||
for k, v in self.type_data_list.items():
|
||||
with open(os.path.join(json_data_dir, v[0]), 'r') as f:
|
||||
json_data = json.load(f)
|
||||
for data in json_data:
|
||||
self.data_list.append({
|
||||
'task_type': k,
|
||||
'prefix': v[1],
|
||||
'duration': data['duration'],
|
||||
'video': data['video'],
|
||||
'question': data['question'],
|
||||
'answer': data['answer'],
|
||||
'candidates': data['candidates'],
|
||||
})
|
||||
|
||||
data_df = pd.DataFrame(self.data_list)
|
||||
data_df = data_df.assign(index=range(len(data_df)))
|
||||
data_df.to_csv(data_file, sep='\t', index=False)
|
||||
|
||||
if modelscope_flag_set():
|
||||
from modelscope import dataset_snapshot_download
|
||||
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
||||
else:
|
||||
hf_token = os.environ.get('HUGGINGFACE_TOKEN')
|
||||
huggingface_hub.login(hf_token)
|
||||
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
|
||||
|
||||
generate_tsv(dataset_path)
|
||||
|
||||
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
||||
return dict(root=dataset_path, data_file=data_file)
|
||||
|
||||
def qa_template(self, data):
|
||||
question = f"Question: {data['question']}\n"
|
||||
question += 'Options:\n'
|
||||
answer = data['answer']
|
||||
answer_idx = -1
|
||||
for idx, c in enumerate(eval(data['candidates'])):
|
||||
question += f"({chr(ord('A') + idx)}) {c}\n"
|
||||
if c == answer:
|
||||
answer_idx = idx
|
||||
question = question.rstrip()
|
||||
answer = f"({chr(ord('A') + answer_idx)}) {answer}"
|
||||
return question, answer
|
||||
|
||||
def save_video_frames(self, line):
|
||||
suffix = line['video'].split('.')[-1]
|
||||
video = line['video'].replace(f'.{suffix}','')
|
||||
vid_path = osp.join(self.data_root, line['prefix'], line['video'])
|
||||
vid = decord.VideoReader(vid_path)
|
||||
video_info = {
|
||||
'fps': vid.get_avg_fps(),
|
||||
'n_frames': len(vid),
|
||||
}
|
||||
if self.nframe > 0 and self.fps < 0:
|
||||
step_size = len(vid) / (self.nframe + 1)
|
||||
indices = [int(i * step_size) for i in range(1, self.nframe + 1)]
|
||||
frame_paths = self.frame_paths(video)
|
||||
elif self.fps > 0:
|
||||
# not constrained by num_frames, get frames by fps
|
||||
total_duration = video_info['n_frames'] / video_info['fps']
|
||||
required_frames = int(total_duration * self.fps)
|
||||
step_size = video_info['fps'] / self.fps
|
||||
indices = [int(i * step_size) for i in range(required_frames)]
|
||||
frame_paths = self.frame_paths_fps(video, len(indices))
|
||||
|
||||
flag = np.all([osp.exists(p) for p in frame_paths])
|
||||
|
||||
if not flag:
|
||||
images = [vid[i].asnumpy() for i in indices]
|
||||
images = [Image.fromarray(arr) for arr in images]
|
||||
for im, pth in zip(images, frame_paths):
|
||||
if not osp.exists(pth):
|
||||
im.save(pth)
|
||||
|
||||
return frame_paths
|
||||
|
||||
def save_video_into_images(self, line):
|
||||
frame_paths = self.save_video_frames(line)
|
||||
return frame_paths
|
||||
|
||||
def build_prompt(self, line, video_llm):
|
||||
if isinstance(line, int):
|
||||
assert line < len(self)
|
||||
line = self.data.iloc[line]
|
||||
|
||||
question, answer = self.qa_template(line)
|
||||
message = [dict(type='text', value=self.SYS, role='system')]
|
||||
message.append(dict(type='text', value=question))
|
||||
video_path = os.path.join(self.data_root, line['prefix'], line['video'])
|
||||
if video_llm:
|
||||
message.append(dict(type='video', value=video_path))
|
||||
else:
|
||||
img_frame_paths = self.save_video_into_images(line)
|
||||
for im in img_frame_paths:
|
||||
message.append(dict(type='image', value=im))
|
||||
message.append(dict(type='text', value='\nOnly give the best option.'))
|
||||
return message
|
||||
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
|
||||
|
||||
tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
|
||||
score_file = eval_file.replace('.xlsx', '_score.xlsx')
|
||||
|
||||
if not osp.exists(score_file):
|
||||
model = judge_kwargs.setdefault('model', 'chatgpt-0125')
|
||||
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
||||
|
||||
if model == 'exact_matching':
|
||||
model = None
|
||||
elif gpt_key_set():
|
||||
model = build_judge(**judge_kwargs)
|
||||
if not model.working():
|
||||
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
||||
warnings.warn(DEBUG_MESSAGE)
|
||||
model = None
|
||||
else:
|
||||
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
||||
model = None
|
||||
res = {} if not osp.exists(tmp_file) else load(tmp_file)
|
||||
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
|
||||
|
||||
data = load(eval_file)
|
||||
data_un = data[~pd.isna(data['prediction'])]
|
||||
|
||||
for idx in data['index']:
|
||||
ans = data.loc[data['index'] == idx, 'answer'].values[0]
|
||||
pred = data.loc[data['index'] == idx, 'prediction'].values[0]
|
||||
options = eval(data.loc[data['index'] == idx, 'candidates'].values[0])
|
||||
answer_idx = -1
|
||||
for id, c in enumerate(options):
|
||||
if c == ans:
|
||||
answer_idx = id
|
||||
ans = f"({chr(ord('A') + answer_idx)}) {ans}"
|
||||
input_item = data.loc[data['index'] == idx].to_dict(orient='records')[0]
|
||||
for id, option_content in enumerate(eval(input_item['candidates'])):
|
||||
input_item[chr(ord('A') + id)] = option_content
|
||||
if option_content == input_item['answer']:
|
||||
input_item['answer'] = chr(ord('A') + id)
|
||||
|
||||
if FAIL_MSG in pred:
|
||||
data.loc[idx, 'score'] = -1
|
||||
else:
|
||||
data.loc[idx, 'score'] = int(check_ans_with_model(
|
||||
pred, ans, model,
|
||||
input_item,
|
||||
'MLVU_MCQ'
|
||||
))
|
||||
|
||||
rejected = [x for x in data['score'] if x == -1]
|
||||
|
||||
print(
|
||||
f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
|
||||
f'failed to obtain the score for another {len(rejected)} questions. '
|
||||
f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
|
||||
)
|
||||
|
||||
dump(data, score_file)
|
||||
|
||||
rating = get_dimension_rating(score_file)
|
||||
return rating
|
||||
|
||||
|
||||
class MLVU_OpenEnded(VideoBaseDataset):
|
||||
|
||||
MD5 = 'cee573a3627c6ac434ded704c60511ba'
|
||||
BASE_SYS = 'Carefully watch this video and pay attention to every detail. '
|
||||
SYS = BASE_SYS + 'Based on your observations, answer the given questions.'
|
||||
TYPE = 'Video-VQA'
|
||||
|
||||
def __init__(self, dataset='MLVU_OpenEnded', nframe=0, fps=-1):
|
||||
self.type_data_list = {
|
||||
'sub_scene': ('8_sub_scene.json', './MLVU/video/8_sub_scene', 'VQA'),
|
||||
'summary': ('9_summary.json', './MLVU/video/9_summary', 'VQA')
|
||||
}
|
||||
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
|
||||
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return ['MLVU_OpenEnded']
|
||||
|
||||
def prepare_dataset(self, dataset_name='MLVU_OpenEnded', repo_id='MLVU/MVLU'):
|
||||
def check_integrity(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
|
||||
if not os.path.exists(data_file):
|
||||
return False
|
||||
|
||||
if md5(data_file) != self.MD5:
|
||||
return False
|
||||
|
||||
data = load(data_file)
|
||||
for idx, item in data.iterrows():
|
||||
if not osp.exists(osp.join(pth, item['prefix'], item['video'])):
|
||||
return False
|
||||
return True
|
||||
|
||||
if modelscope_flag_set():
|
||||
repo_id = "AI-ModelScope/MLVU"
|
||||
|
||||
cache_path = get_cache_path(repo_id)
|
||||
if cache_path is not None and check_integrity(cache_path):
|
||||
dataset_path = cache_path
|
||||
else:
|
||||
def generate_tsv(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
if os.path.exists(data_file) and md5(data_file) == self.MD5:
|
||||
return
|
||||
json_data_dir = os.path.join(dataset_path, 'MLVU', 'json')
|
||||
self.data_list = []
|
||||
for k, v in self.type_data_list.items():
|
||||
with open(os.path.join(json_data_dir, v[0]), 'r') as f:
|
||||
json_data = json.load(f)
|
||||
for data in json_data:
|
||||
self.data_list.append({
|
||||
'task_type': k,
|
||||
'prefix': v[1],
|
||||
'duration': data['duration'],
|
||||
'video': data['video'],
|
||||
'question': data['question'],
|
||||
'answer': data['answer'],
|
||||
'scoring_points': data['scoring_points'] if 'scoring_points' in data else ''
|
||||
})
|
||||
|
||||
data_df = pd.DataFrame(self.data_list)
|
||||
data_df = data_df.assign(index=range(len(data_df)))
|
||||
data_df.to_csv(data_file, sep='\t', index=False)
|
||||
|
||||
if modelscope_flag_set():
|
||||
from modelscope import dataset_snapshot_download
|
||||
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
||||
else:
|
||||
hf_token = os.environ.get('HUGGINGFACE_TOKEN')
|
||||
huggingface_hub.login(hf_token)
|
||||
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
|
||||
|
||||
generate_tsv(dataset_path)
|
||||
|
||||
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
||||
return dict(root=dataset_path, data_file=data_file)
|
||||
|
||||
def qa_template(self, data):
|
||||
question = f"{data['question']}"
|
||||
answer = data['answer']
|
||||
return question, answer
|
||||
|
||||
def save_video_frames(self, line):
|
||||
suffix = line['video'].split('.')[-1]
|
||||
video = line['video'].replace(f'.{suffix}','')
|
||||
vid_path = osp.join(self.data_root, line['prefix'], line['video'])
|
||||
vid = decord.VideoReader(vid_path)
|
||||
video_info = {
|
||||
'fps': vid.get_avg_fps(),
|
||||
'n_frames': len(vid),
|
||||
}
|
||||
if self.nframe > 0 and self.fps < 0:
|
||||
step_size = len(vid) / (self.nframe + 1)
|
||||
indices = [int(i * step_size) for i in range(1, self.nframe + 1)]
|
||||
frame_paths = self.frame_paths(video)
|
||||
elif self.fps > 0:
|
||||
# not constrained by num_frames, get frames by fps
|
||||
total_duration = video_info['n_frames'] / video_info['fps']
|
||||
required_frames = int(total_duration * self.fps)
|
||||
step_size = video_info['fps'] / self.fps
|
||||
indices = [int(i * step_size) for i in range(required_frames)]
|
||||
frame_paths = self.frame_paths_fps(video, len(indices))
|
||||
|
||||
flag = np.all([osp.exists(p) for p in frame_paths])
|
||||
|
||||
if not flag:
|
||||
images = [vid[i].asnumpy() for i in indices]
|
||||
images = [Image.fromarray(arr) for arr in images]
|
||||
for im, pth in zip(images, frame_paths):
|
||||
if not osp.exists(pth):
|
||||
im.save(pth)
|
||||
|
||||
return frame_paths
|
||||
|
||||
def save_video_into_images(self, line):
|
||||
frame_paths = self.save_video_frames(line)
|
||||
return frame_paths
|
||||
|
||||
def build_prompt(self, line, video_llm):
|
||||
if isinstance(line, int):
|
||||
assert line < len(self)
|
||||
line = self.data.iloc[line]
|
||||
|
||||
question, answer = self.qa_template(line)
|
||||
message = [dict(type='text', value=self.SYS, role='system')]
|
||||
message.append(dict(type='text', value=question))
|
||||
video_path = os.path.join(self.data_root, line['prefix'], line['video'])
|
||||
if video_llm:
|
||||
message.append(dict(type='video', value=video_path))
|
||||
else:
|
||||
img_frame_paths = self.save_video_into_images(line)
|
||||
for im in img_frame_paths:
|
||||
message.append(dict(type='image', value=im))
|
||||
return message
|
||||
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
|
||||
model = judge_kwargs['model'] if 'model' in judge_kwargs else judge_kwargs.setdefault('model', 'gpt-4-0125')
|
||||
if model != 'gpt-4-0125':
|
||||
print('MLVU OpenEnded evaluation defaults to gpt-4-0125, so the judge model is changed to gpt-4-0125.')
|
||||
judge_kwargs['model'] = 'gpt-4-0125'
|
||||
|
||||
suffix = eval_file.split('.')[-1]
|
||||
score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.xlsx')
|
||||
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
||||
nproc = judge_kwargs.pop('nproc', 4)
|
||||
|
||||
if not osp.exists(score_file):
|
||||
data = load(eval_file)
|
||||
model_dict = {
|
||||
'sub_scene': build_judge(system_prompt=system_prompt_sub_scene, **judge_kwargs),
|
||||
'summary': build_judge(system_prompt=system_prompt_summary, **judge_kwargs)
|
||||
}
|
||||
lt = len(data)
|
||||
lines = [data.iloc[i] for i in range(lt)]
|
||||
tups = [(model_dict[line['task_type']], line) for line in lines]
|
||||
indices = [line['index'] for line in lines]
|
||||
|
||||
ans = {}
|
||||
if osp.exists(tmp_file):
|
||||
ans = load(tmp_file)
|
||||
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
||||
indices = [i for i in indices if i not in ans]
|
||||
|
||||
if len(indices):
|
||||
_ = track_progress_rich(
|
||||
MLVU_OpenEnded_generate,
|
||||
tups,
|
||||
nproc=nproc,
|
||||
chunksize=nproc,
|
||||
keys=indices,
|
||||
save=tmp_file,
|
||||
)
|
||||
ans = load(tmp_file)
|
||||
data = MLVU_OpenEnded_extract(ans, data)
|
||||
dump(data, score_file)
|
||||
|
||||
rating = get_dimension_rating(score_file)
|
||||
return rating
|
||||
256
eval_mm/vlmevalkit/vlmeval/dataset/mmbench_video.py
Normal file
@@ -0,0 +1,256 @@
|
||||
from huggingface_hub import snapshot_download
|
||||
from ..smp import *
|
||||
from .video_base import VideoBaseDataset
|
||||
from .utils import build_judge, DEBUG_MESSAGE
|
||||
from ..utils import track_progress_rich
|
||||
|
||||
|
||||
FAIL_MSG = 'Failed to obtain answer via API.'
|
||||
|
||||
|
||||
def unwrap_hf_pkl(pth, suffix='.mp4'):
|
||||
base_dir = os.path.join(pth, 'video_pkl/')
|
||||
target_dir = os.path.join(pth, 'video/')
|
||||
pickle_files = [os.path.join(base_dir, file) for file in os.listdir(base_dir)]
|
||||
pickle_files.sort()
|
||||
|
||||
if not os.path.exists(target_dir):
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
for pickle_file in pickle_files:
|
||||
with open(pickle_file, 'rb') as file:
|
||||
video_data = pickle.load(file)
|
||||
# For each video file in the pickle file, write its contents to a new mp4 file
|
||||
for video_name, video_content in video_data.items():
|
||||
output_path = os.path.join(target_dir, f'{video_name}{suffix}')
|
||||
with open(output_path, 'wb') as output_file:
|
||||
output_file.write(video_content)
|
||||
print('The video files have been restored from the pickle files.')
|
||||
else:
|
||||
print('The video files already exist.')
|
||||
|
||||
|
||||
class MMBenchVideo(VideoBaseDataset):
|
||||
|
||||
MD5 = '98f7df3eb1007fc375ea6fe88a98e2ff'
|
||||
SYS = 'You are an AI assistant responsible for answering questions about videos.'
|
||||
FRAMES_TMPL_PACK = """
|
||||
You will be provided with {} separate frames uniformly sampled from a video, \
|
||||
the frames are provided in chronological order of the video.
|
||||
Please analyze these images and provide the answer / answers to the \
|
||||
following question / questions about the video content.
|
||||
If multiple questions are provided (with indices I1, I2, I3, ...), \
|
||||
you should organize your answers in the following json format:
|
||||
{{
|
||||
'I1': 'Answer to Question I1',
|
||||
'I2': 'Answer to Question I2',
|
||||
...
|
||||
}}
|
||||
Otherwise, please directly reply with your response to the only question.
|
||||
Even if the information in these separate frames is not enough to give an answer,
|
||||
PLEASE GIVE A RESPONSE TO EACH OF THE QUESTIONS IN THE FORMAT DESCRIBED ABOVE.
|
||||
"""
|
||||
|
||||
FRAMES_TMPL_NOPACK = """
|
||||
You will be provided with {} separate frames uniformly sampled from a video, \
|
||||
the frames are provided in chronological order of the video.
|
||||
Please analyze these images and provide the answer to the question about the video content.
|
||||
Please directly reply with your response to the only question.
|
||||
"""
|
||||
|
||||
TYPE = 'Video-VQA'
|
||||
|
||||
def __init__(self, dataset='MMBench-Video', pack=False, nframe=0, fps=-1):
|
||||
super().__init__(dataset=dataset, pack=pack, nframe=nframe, fps=fps)
|
||||
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return ['MMBench-Video']
|
||||
|
||||
def prepare_dataset(self, dataset_name='MMBench-Video', repo_id='opencompass/MMBench-Video'):
|
||||
def check_integrity(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
if md5(data_file) != self.MD5:
|
||||
return False
|
||||
data = load(data_file)
|
||||
for video_pth in data['video_path']:
|
||||
if not osp.exists(osp.join(pth, video_pth)):
|
||||
return False
|
||||
return True
|
||||
|
||||
cache_path = get_cache_path(repo_id)
|
||||
if cache_path is not None and check_integrity(cache_path):
|
||||
dataset_path = cache_path
|
||||
else:
|
||||
if modelscope_flag_set():
|
||||
from modelscope import dataset_snapshot_download
|
||||
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
||||
else:
|
||||
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
|
||||
unwrap_hf_pkl(dataset_path)
|
||||
self.video_path = osp.join(dataset_path, 'video/')
|
||||
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
||||
|
||||
return dict(data_file=data_file, root=osp.join(dataset_path, 'video'))
|
||||
|
||||
def build_prompt_pack(self, line):
|
||||
if isinstance(line, int):
|
||||
assert line < len(self)
|
||||
video = self.videos[line]
|
||||
elif isinstance(line, pd.Series):
|
||||
video = line['video']
|
||||
elif isinstance(line, str):
|
||||
video = line
|
||||
|
||||
frames = self.save_video_frames(video)
|
||||
sub = self.data[self.data['video'] == video]
|
||||
sys_prompt = self.SYS + self.FRAMES_TMPL_PACK.format(len(frames))
|
||||
message = [dict(type='text', value=sys_prompt)]
|
||||
for im in frames:
|
||||
message.append(dict(type='image', value=im))
|
||||
nq = len(sub)
|
||||
prompt = 'Questions: \n{}\nAnswers: \n'
|
||||
qs = {int(sub.iloc[i]['index']): sub.iloc[i]['question'] for i in range(nq)}
|
||||
prompt = prompt.format(json.dumps(qs))
|
||||
message.append(dict(type='text', value=prompt))
|
||||
return message
|
||||
|
||||
def build_prompt_nopack(self, line, video_llm):
|
||||
if isinstance(line, int):
|
||||
assert line < len(self)
|
||||
line = self.data.iloc[line]
|
||||
if video_llm:
|
||||
question = line['question']
|
||||
prefix, video_idx_path = os.path.split(line['video_path'])
|
||||
message = [dict(type='text', value=question)]
|
||||
message.append(dict(type='video', value=os.path.join(self.video_path, video_idx_path)))
|
||||
return message
|
||||
else:
|
||||
frames = self.save_video_frames(line['video'])
|
||||
sys_prompt = self.FRAMES_TMPL_NOPACK.format(len(frames))
|
||||
message = [dict(type='text', value=sys_prompt)]
|
||||
for im in frames:
|
||||
message.append(dict(type='image', value=im))
|
||||
prompt = 'Question: {}\nAnswer: '.format(line['question'])
|
||||
message.append(dict(type='text', value=prompt))
|
||||
return message
|
||||
|
||||
def build_prompt(self, line, video_llm):
|
||||
if self.pack and not video_llm:
|
||||
return self.build_prompt_pack(line)
|
||||
else:
|
||||
return self.build_prompt_nopack(line, video_llm)
|
||||
|
||||
@staticmethod
|
||||
def remove_side_quote(s, syms=[',', '"', "'"]):
|
||||
if np.all([x in syms for x in s]):
|
||||
return ''
|
||||
while s[0] in syms:
|
||||
s = s[1:]
|
||||
while s[-1] in syms:
|
||||
s = s[:-1]
|
||||
return s
|
||||
|
||||
@staticmethod
|
||||
def robust_json_load(s):
|
||||
try:
|
||||
jsons = list(extract_json_objects(s))
|
||||
assert len(jsons) == 1
|
||||
return jsons[0]
|
||||
except:
|
||||
if '{' in s and s.find('{') == s.rfind('{'):
|
||||
sub_str = s[s.find('{') + 1:].strip()
|
||||
lines = sub_str.split('\n')
|
||||
res = {}
|
||||
for l in lines:
|
||||
l = l.strip()
|
||||
if ': ' in l:
|
||||
key = l.split(': ')[0].strip()
|
||||
val = l.split(': ')[1].strip()
|
||||
key = MMBenchVideo.remove_side_quote(key)
|
||||
val = MMBenchVideo.remove_side_quote(val)
|
||||
if len(key) and len(val):
|
||||
res[key] = val
|
||||
return res
|
||||
return None
|
||||
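A made-up packed model reply illustrating `robust_json_load`: single-quoted keys are not valid JSON, so either the lenient `extract_json_objects` helper from `smp` (not shown in this diff) or the line-by-line fallback recovers the mapping.

```python
# Made-up packed answer; the expected result is the per-question dict below.
s = "Here are my answers:\n{\n'I1': 'A red car',\n'I2': 'Two people walking'\n}"
print(MMBenchVideo.robust_json_load(s))
# {'I1': 'A red car', 'I2': 'Two people walking'}
```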
|
||||
def load_pack_answers(self, data_raw):
|
||||
vstats = defaultdict(lambda: 0)
|
||||
data = defaultdict(lambda: {})
|
||||
|
||||
for k in data_raw:
|
||||
ans = data_raw[k].strip()
|
||||
if FAIL_MSG in ans:
|
||||
vstats['GEN_FAIL'] += 1
|
||||
continue
|
||||
res = self.robust_json_load(ans)
|
||||
if res is not None:
|
||||
data[k] = res
|
||||
vstats['PARSE_OK'] += 1
|
||||
else:
|
||||
vstats['PARSE_FAIL'] += 1
|
||||
|
||||
# return data
|
||||
meta = cp.deepcopy(self.data)
|
||||
lt = len(meta)
|
||||
prediction = []
|
||||
for i in range(lt):
|
||||
line = meta.iloc[i]
|
||||
vid = line['video']
|
||||
idx = str(line['index'])
|
||||
prediction.append(data[vid][idx] if idx in data[vid] else None)
|
||||
meta['prediction'] = prediction
|
||||
vstats['VALIDQ'] = len([x for x in prediction if x is not None])
|
||||
vstats['INVALIDQ'] = len([x for x in prediction if x is None])
|
||||
return meta, vstats
|
||||
|
||||
# It returns a dictionary
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
from .utils.mmbench_video import get_dimension_rating, system_prompt, build_prompt
|
||||
|
||||
assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
|
||||
judge = judge_kwargs['model']
|
||||
nproc = judge_kwargs.pop('nproc', 4)
|
||||
|
||||
tmp_file = eval_file.replace('.xlsx', f'_{judge}_tmp.pkl')
|
||||
tgt_file = eval_file.replace('.xlsx', f'_{judge}_rating.json')
|
||||
score_file = eval_file.replace('.xlsx', f'_{judge}_score.xlsx')
|
||||
|
||||
model = build_judge(system_prompt=system_prompt, **judge_kwargs)
|
||||
assert model.working(), 'MMBench-Video evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE
|
||||
|
||||
if not osp.exists(score_file):
|
||||
res = {} if not osp.exists(tmp_file) else load(tmp_file)
|
||||
res = {k: v for k, v in res.items() if model.fail_msg not in v}
|
||||
|
||||
data = load(eval_file)
|
||||
data_un = data[~data['index'].isin(res)]
|
||||
data_un = data_un[~pd.isna(data_un['prediction'])]
|
||||
lt = len(data_un)
|
||||
prompts = [build_prompt(data_un.iloc[i]) for i in range(lt)]
|
||||
indices = [data_un.iloc[i]['index'] for i in range(lt)]
|
||||
|
||||
if len(prompts):
|
||||
_ = track_progress_rich(
|
||||
model.generate,
|
||||
prompts,
|
||||
keys=indices,
|
||||
save=tmp_file,
|
||||
nproc=nproc,
|
||||
chunksize=nproc
|
||||
)
|
||||
score_map = load(tmp_file)
|
||||
data['score'] = [score_map[idx] if idx in score_map else -1 for idx in data['index']]
|
||||
rejected = [x for x in score_map.values() if FAIL_MSG in x]
|
||||
data['score'] = [int(x) if istype(x, int) else -1 for x in data['score']]
|
||||
print(
|
||||
f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(score_map)} questions, '
|
||||
f'failed to obtain the score for another {len(rejected)} questions. '
|
||||
f'Those questions will be counted as 0 score in ALL rating, and will not be counted in VALID rating.'
|
||||
)
|
||||
|
||||
dump(data, score_file)
|
||||
|
||||
rating = get_dimension_rating(score_file)
|
||||
dump(rating, tgt_file)
|
||||
return rating
|
||||
69
eval_mm/vlmevalkit/vlmeval/dataset/mmgenbench.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import warnings
|
||||
import pandas as pd
|
||||
from abc import abstractmethod
|
||||
from ..smp import *
|
||||
from .image_base import ImageBaseDataset
|
||||
|
||||
|
||||
class MMGenBench(ImageBaseDataset):
|
||||
|
||||
prompt_list = [
|
||||
"""
|
||||
# Role
|
||||
You are an expert in the field of image understanding, focusing on the \
|
||||
understanding of images and generating the image caption-prompt.
|
||||
|
||||
# Definition Explanation
|
||||
image caption-prompt: Refers to the caption or description of an image, \
|
||||
used to provide to a Text-to-Image model to generate a new image.
|
||||
Text-to-Image model: Can generate a new image based on the provided image \
|
||||
caption-prompt, such as stable diffusion 3, flux, and other image generation models.
|
||||
|
||||
# Task Description
|
||||
Generate an image caption-prompt based on the input image.
|
||||
|
||||
# Key Points and Requirements
|
||||
1. Accurately understand the input image and precisely generate an image caption-prompt.
|
||||
2. The generated image caption-prompt, when provided to the Text-to-Image model, requires the \
|
||||
Text-to-Image model to generate a new image that is as consistent as possible with the input image.
|
||||
3. The generated image caption-prompt must conform to the preferences of the Text-to-Image model.
|
||||
4. The generated image caption-prompt should describe the input image in as much \
|
||||
detail as possible, and it should be between 20 to 60 words.
|
||||
|
||||
# Output Format
|
||||
A string, that is the image caption-prompt. No extra output needed.
|
||||
"""
|
||||
]
|
||||
TYPE = 'GenerateImgPrompt'
|
||||
DATASET_URL = {
|
||||
'MMGenBench-Test': 'https://huggingface.co/datasets/lerogo/MMGenBench/resolve/main/MMGenBench-Test.tsv',
|
||||
'MMGenBench-Domain': 'https://huggingface.co/datasets/lerogo/MMGenBench/resolve/main/MMGenBench-Domain.tsv',
|
||||
}
|
||||
PROMPT_MAP = {
|
||||
'MMGenBench-Test': prompt_list[0],
|
||||
'MMGenBench-Domain': prompt_list[0],
|
||||
}
|
||||
DATASET_MD5 = {
|
||||
'MMGenBench-Test': "94f8dac6bbf7c20be403f99adeaa73da",
|
||||
'MMGenBench-Domain': "5c10daf6e2c5f08bdfb0701aa6db86bb",
|
||||
}
|
||||
|
||||
def __init__(self, dataset='MMGenBench', **kwargs):
|
||||
super().__init__(dataset, **kwargs)
|
||||
warnings.warn('This dataset is for inference only and does not support direct output of evaluation results.\n')
|
||||
warnings.warn('Please refer to "https://github.com/lerogo/MMGenBench" for more evaluation information.\n')
|
||||
|
||||
def load_data(self, dataset):
|
||||
data = super().load_data(dataset)
|
||||
if 'question' not in data:
|
||||
data['question'] = [(
|
||||
self.PROMPT_MAP[dataset]
|
||||
)] * len(data)
|
||||
return data
|
||||
|
||||
# Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
|
||||
@abstractmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
warnings.warn('This evaluation method is not supported.\n')
|
||||
warnings.warn('Please refer to "https://github.com/lerogo/MMGenBench" for more evaluation information.\n')
|
||||
return None
|
||||
584
eval_mm/vlmevalkit/vlmeval/dataset/mmlongbench.py
Normal file
@@ -0,0 +1,584 @@
|
||||
import re
|
||||
import math
|
||||
from urllib.request import urlopen
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from vlmeval.dataset.utils import build_judge, levenshtein_distance
|
||||
from vlmeval.smp import *
|
||||
from .image_base import ImageBaseDataset
|
||||
|
||||
FAIL_MSG = 'Failed to obtain answer via API.'
|
||||
|
||||
|
||||
def get_gpt4_ICE():
|
||||
example_1 = """
|
||||
---
|
||||
Question: List the primary questions asked about the services in this report.
|
||||
Analysis: The primary questions asked about the services in the report for The Limes Residential Home are:\n\n
|
||||
1. Is the service safe?\n
|
||||
2. Is the service effective?\n
|
||||
3. Is the service caring?\n
|
||||
4. Is the service responsive?\n
|
||||
5. Is the service well-led?
|
||||
Extracted answer: [
|
||||
'Is the service safe?',
'Is the service effective?',
'Is the service caring?',
|
||||
'Is the service responsive?',
|
||||
'Is the service well-led?'
|
||||
]
|
||||
Answer format: List\n
|
||||
"""
|
||||
|
||||
example_2 = """
|
||||
---
|
||||
Question: How many regulations of the HSCA 2008 are breached in all according to this report?
|
||||
Analysis: According to the report, the provider breached 10 Health and Social Care Act 2008 (Regulated Activities)
|
||||
Regulations in total. Here are the specifics:\n\n1. Regulation 13: Safeguarding service users from abuse and
|
||||
improper treatment\n2. Regulation 12: Safe care and treatment\n3. Regulation 18: Staffing\n4. Regulation 11:
|
||||
Need for consent\n5. Regulation 10: Dignity and respect\n6. Regulation 9: Person-centred care\n7. Regulation 17:
|
||||
Good governance\n8. Regulation 18 (CQC Registration Regulations 2009): Notification of other incidents\n9.
|
||||
Regulation 18: Failure to maintain an accurate and up-to-date care plan\n10. Regulation 11: Failure to implement
|
||||
the Mental Capacity Act 2005 code of practice effectively\n\nThese breaches involve issues concerning staffing,
|
||||
safeguarding, medicines management, dignity and respect, consent, care planning, governance, and failure to
|
||||
notify the CQC of incidents.
|
||||
Extracted answer: 10
|
||||
Answer format: Integer\n
|
||||
"""
|
||||
|
||||
example_3 = """
|
||||
---
|
||||
Question: According to the survey that is the percentage of Chinese who are paying more or
|
||||
about the same attention to politics after Trump's election?
|
||||
Analysis: The survey provided does not specify the percentage of Chinese individuals specifically who are paying
|
||||
more or about the same attention to politics after Trump's election. The report focuses primarily on American
|
||||
demographics and does not include specific details about the Chinese population in relation to this question. If
|
||||
you need information about a different demographic or a summary of the findings from the American demographic,
|
||||
I can certainly help with that!
|
||||
Extracted answer: Not answerable
|
||||
Answer format: String\n
|
||||
"""
|
||||
|
||||
example_4 = """
|
||||
---
|
||||
Question: How many quotations from male respondent over 50 years old are included in this report?
|
||||
Analysis: The image you've provided appears to be a screenshot of a document with multiple charts. However, the
|
||||
text is too small and blurry to read accurately. If you can provide a clearer image or more context, I might be
|
||||
able to help you with your question.
|
||||
Extracted answer: Fail to answer
|
||||
Answer format: String\n
|
||||
"""
|
||||
|
||||
return [example_1, example_2, example_3, example_4]
|
||||
|
||||
|
||||
def build_mmlongbench_gpt4_prompt(line):
|
||||
task_description = """
|
||||
Given the question and analysis, you are tasked to extract answers with required formats from the free-form analysis.
|
||||
- Your extracted answers should be one of the following formats: (1) Integer, (2) Float, (3) String and (4) List.
|
||||
If you find from the analysis that the question cannot be answered from the given documents, type "Not answerable".
|
||||
Exception: If the analysis only tells you that it can not read/understand the images or documents,
|
||||
type "Fail to answer".
|
||||
- Please make your response as concise as possible. Also note that your response should be formatted as below:
|
||||
```
|
||||
Extracted answer: [answer]
|
||||
Answer format: [answer format]
|
||||
```
|
||||
Please read the following example, then extract the answer from the model response
|
||||
and type it at the end of the prompt.\n
|
||||
"""
|
||||
question = line['question']
|
||||
prediction = str(line['prediction'])
|
||||
prompt = task_description
|
||||
examples = get_gpt4_ICE()
|
||||
for example in examples:
|
||||
prompt += example
|
||||
prompt += '---\nQuestion:' + question + '\n'
|
||||
prompt += 'Analysis: ' + prediction
|
||||
return prompt
|
||||
|
||||
|
||||
def anls_compute(groundtruth, prediction, threshold=0.5):
|
||||
dist = levenshtein_distance(groundtruth, prediction)
|
||||
length = max(len(groundtruth.upper()), len(prediction.upper()))
|
||||
value = 0.0 if length == 0 else float(dist) / float(length)
|
||||
anls = 1.0 - value
|
||||
if anls <= threshold:
|
||||
anls = 0.0
|
||||
return anls
|
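A worked example of the thresholded ANLS above, using made-up strings (exact values depend on the `levenshtein_distance` helper imported at the top of the file):

```python
# One-character typo: edit distance 1 over length 26, so ANLS is about 1 - 1/26 = 0.96.
print(anls_compute('The Limes Residential Home', 'The Limes Residental Home'))
# A largely different string falls at or below the 0.5 threshold and is clipped to 0.0.
print(anls_compute('10 regulations', 'not answerable'))  # 0.0
```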
||||
|
||||
|
||||
def is_float_equal(reference, prediction, include_percentage: bool = False, is_close: float = False) -> bool:
|
||||
def get_precision(gt_ans: float) -> int:
|
||||
precision = 3
|
||||
if '.' in str(gt_ans):
|
||||
precision = len(str(gt_ans).split('.')[-1])
|
||||
return precision
|
||||
|
||||
reference = float(str(reference).strip().rstrip('%').strip())
|
||||
try:
|
||||
prediction = float(str(prediction).strip().rstrip('%').strip())
|
||||
except:
|
||||
return False
|
||||
|
||||
if include_percentage:
|
||||
gt_result = [reference / 100, reference, reference * 100]
|
||||
else:
|
||||
gt_result = [reference]
|
||||
for item in gt_result:
|
||||
try:
|
||||
if is_close:
|
||||
if math.isclose(item, prediction, rel_tol=0.01):
|
||||
return True
|
||||
precision = max(min(get_precision(prediction), get_precision(item)), 2)
|
||||
if round(prediction, precision) == round(item, precision):
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
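Two made-up calls showing the percentage handling in `is_float_equal`: with `include_percentage=True`, a ground truth of 25 also matches 0.25 and 2500.

```python
print(is_float_equal(25, '0.25', include_percentage=True, is_close=True))  # True
print(is_float_equal(25, '30', include_percentage=True, is_close=True))   # False
```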
|
||||
|
||||
def get_clean_string(s):
|
||||
s = str(s).lower().strip()
|
||||
if s.endswith('mile'):
    s = s.rstrip('mile').strip()
if s.endswith('miles'):
    s = s.rstrip('miles').strip()
if s.endswith('million'):
    s = s.rstrip('million').strip()
|
||||
# remove parenthesis
|
||||
s = re.sub(r'\s*\([^)]*\)', '', s).strip()
|
||||
# remove quotes
|
||||
s = re.sub(r"^['\"]|['\"]$", '', s).strip()
|
||||
s = s.strip().lstrip('$').strip()
|
||||
s = s.strip().rstrip('%').strip()
|
||||
return s
|
||||
|
||||
|
||||
def is_exact_match(s):
|
||||
flag = False
|
||||
# Website
|
||||
if 'https://' in s:
|
||||
flag = True
|
||||
# code file
|
||||
if s.endswith('.py') or s.endswith('ipynb'):
|
||||
flag = True
|
||||
if s.startswith('page'):
|
||||
flag = True
|
||||
# telephone number
|
||||
if re.fullmatch(r'\b\d+(-\d+|\s\d+)?\b', s):
|
||||
flag = True
|
||||
# time
|
||||
if 'a.m.' in s or 'p.m.' in s:
|
||||
flag = True
|
||||
# YYYY-MM-DD
|
||||
if re.fullmatch(r'\b\d{4}[-\s]\d{2}[-\s]\d{2}\b', s):
|
||||
flag = True
|
||||
# YYYY-MM
|
||||
if re.fullmatch(r'\b\d{4}[-\s]\d{2}\b', s):
|
||||
flag = True
|
||||
# Email address
|
||||
if re.fullmatch(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', s):
|
||||
flag = True
|
||||
return flag
|
||||
|
||||
|
||||
def isfloat(num):
|
||||
try:
|
||||
float(num)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def get_font():
|
||||
try:
|
||||
truetype_url = "http://opencompass.openxlab.space/utils/Fonts/SimHei.ttf"
|
||||
ff = urlopen(truetype_url)
|
||||
font = ImageFont.truetype(ff, size=40)
|
||||
except Exception as e:
|
||||
logging.warning(f'{type(e)}: {e}')
|
||||
logging.warning("Fail to download the font. Use the default one.")
|
||||
font = ImageFont.load_default(size=40)
|
||||
return font
|
||||
|
||||
|
||||
def frame2img(img_path_list, font, save_path=None, idx_start=0):
|
||||
imgs = [Image.open(img_path) for img_path in img_path_list]
|
||||
|
||||
new_imgs = []
|
||||
for img in imgs:
|
||||
w, h = img.size
|
||||
scale = w / h
|
||||
if w > h:
|
||||
new_w = 560 * 2
|
||||
new_h = int(560 * 2 / scale)
|
||||
else:
|
||||
new_w = int(560 * 2 * scale)
|
||||
new_h = 560 * 2
|
||||
img = transforms.functional.resize(img, [new_h, new_w],)
|
||||
new_imgs.append(img)
|
||||
imgs = new_imgs
|
||||
new_w = 0
|
||||
new_h = 0
|
||||
pad = 40
|
||||
if w > h:
|
||||
for im in imgs:
|
||||
w, h = im.size
|
||||
new_w = max(new_w, w)
|
||||
new_h += h + 10 + pad
|
||||
new_img = Image.new("RGB", (new_w, new_h), "white")
|
||||
draw = ImageDraw.Draw(new_img)
|
||||
curr_h = 0
|
||||
for idx, im in enumerate(imgs):
|
||||
w, h = im.size
|
||||
new_img.paste(im, (0, pad + curr_h))
|
||||
draw.text((0, curr_h), f"<IMAGE {idx+idx_start}>", font=font, fill="black")
|
||||
if idx + 1 < len(imgs):
|
||||
draw.line([(0, pad + curr_h + h + 5), (new_w, pad + curr_h + h + 5)], fill='black', width=2)
|
||||
curr_h += h + 10 + pad
|
||||
else:
|
||||
for im in imgs:
|
||||
w, h = im.size
|
||||
new_w += w + 10
|
||||
new_h = max(new_h, h)
|
||||
new_h += pad
|
||||
new_img = Image.new('RGB', (new_w, new_h), 'white')
|
||||
draw = ImageDraw.Draw(new_img)
|
||||
curr_w = 0
|
||||
for idx, im in enumerate(imgs):
|
||||
w, h = im.size
|
||||
new_img.paste(im, (curr_w, pad))
|
||||
draw.text((curr_w, 0), f"<IMAGE {idx+idx_start}>", font=font, fill='black')
|
||||
if idx + 1 < len(imgs):
|
||||
draw.line([(curr_w + w + 5, 0), (curr_w + w + 5, new_h)], fill='black', width=2)
|
||||
curr_w += w + 10
|
||||
|
||||
if save_path is not None:
|
||||
new_img.save(save_path)
|
||||
|
||||
return new_img
|
||||
|
||||
|
||||
def concat_images(image_list, max_concat=1, column_num=1):
|
||||
concatenated_images = []
|
||||
if column_num == -1:
|
||||
MAX_COLUMN_NUM = 20
|
||||
max_concat = 1
|
||||
while len(image_list) / max_concat > MAX_COLUMN_NUM:
|
||||
max_concat += 1
|
||||
interval = max(math.ceil(len(image_list) / max_concat), 1)
|
||||
for i in range(0, len(image_list), interval):
|
||||
batch_images = image_list[i:i + interval]
|
||||
concatenated_image = frame2img(batch_images, font=get_font(), idx_start=i)
|
||||
concatenated_images.append(concatenated_image)
|
||||
else:
|
||||
interval = max(math.ceil(len(image_list) / max_concat), 1)
|
||||
for i in range(0, len(image_list), interval):
|
||||
batch_images = [Image.open(filename) for filename in image_list[i:i + interval]]
|
||||
if column_num == 1:
|
||||
total_height = batch_images[0].height * len(batch_images)
|
||||
else:
|
||||
total_height = batch_images[0].height * ((len(batch_images) - 1) // column_num + 1)
|
||||
concatenated_image = Image.new('RGB', (batch_images[0].width * column_num, total_height), 'white')
|
||||
|
||||
x_offset, y_offset = 0, 0
|
||||
for count, image in enumerate(batch_images):
|
||||
concatenated_image.paste(image, (x_offset, y_offset))
|
||||
x_offset += image.width
|
||||
if (count + 1) % column_num == 0:
|
||||
y_offset += image.height
|
||||
x_offset = 0
|
||||
concatenated_images.append(concatenated_image)
|
||||
return concatenated_images
|
||||
|
||||
|
||||
def eval_score(gt, pred, answer_type):
|
||||
if answer_type == 'Int':
|
||||
try:
|
||||
gt, pred = int(gt), int(float(pred))
|
||||
except:
|
||||
pred = ''
|
||||
score = (gt == pred)
|
||||
elif answer_type == 'Float':
|
||||
try:
|
||||
gt = float(get_clean_string(str(gt)))
|
||||
pred = float(get_clean_string(str(pred)))
|
||||
except:
|
||||
pred = ''
|
||||
score = is_float_equal(gt, pred, include_percentage=True, is_close=True)
|
||||
elif answer_type == 'Str':
|
||||
gt = get_clean_string(gt)
|
||||
pred = get_clean_string(pred)
|
||||
if is_exact_match(gt):
|
||||
score = (gt == pred)
|
||||
else:
|
||||
score = anls_compute(gt, pred)
|
||||
else:
|
||||
if isinstance(gt, str) and gt.startswith('['):
|
||||
gt = eval(gt)
|
||||
if not isinstance(gt, list):
|
||||
gt = [gt]
|
||||
if isinstance(pred, str) and pred.startswith('['):
|
||||
pred = eval(pred)
|
||||
if not isinstance(pred, list):
|
||||
pred = [pred]
|
||||
print(len(gt), len(pred))
|
||||
if len(gt) != len(pred):
|
||||
score = 0.0
|
||||
else:
|
||||
gt = sorted([get_clean_string(a) for a in gt])
|
||||
pred = sorted([get_clean_string(a) for a in pred])
|
||||
print(gt, pred)
|
||||
if isfloat(gt[0]) or is_exact_match(gt[0]):
|
||||
score = ('-'.join(gt) == '-'.join(pred))
|
||||
else:
|
||||
score = min([anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred)])
|
||||
|
||||
return float(score)
|
||||
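# --- Illustrative usage sketch (added for exposition; not part of the original file) ---
# eval_score dispatches on MMLongBench's answer_format column: 'Int' uses exact match
# after casting, 'Float' uses a tolerance check, 'Str' uses ANLS (or exact match for
# enumerable answers), and list-style answers are matched element-wise. The guarded demo
# below only exercises the 'Int' branch, which needs no helpers beyond the built-ins;
# the values are hypothetical.
if __name__ == '__main__':
    print(eval_score('42', '42.0', 'Int'))   # -> 1.0, since int(float('42.0')) == 42
    print(eval_score('7', 'eight', 'Int'))   # -> 0.0, the prediction cannot be cast to int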
|
||||
|
||||
def MMLongBench_auxeval(model, line):
|
||||
prompt = build_mmlongbench_gpt4_prompt(line)
|
||||
log = ''
|
||||
retry = 5
|
||||
|
||||
for i in range(retry):
|
||||
prediction = line['prediction']
|
||||
res = model.generate(prompt, temperature=i * 0.5)
|
||||
|
||||
if FAIL_MSG in res:
|
||||
log += f'Try {i}: output is {prediction}, failed to parse.\n'
|
||||
else:
|
||||
log += 'Succeed'
|
||||
try:
|
||||
pred = res.split('Answer format:')[0].split('Extracted answer:')[1].strip()
|
||||
except:
|
||||
pred = ''
|
||||
return dict(log=log, res=res, pred=pred)
|
||||
log += 'All 5 retries failed.\n'
|
||||
return dict(log=log, res='', pred='')
|
||||
|
||||
|
||||
def get_f1(data):
|
||||
gt_pos_data = data[data.apply(lambda k: k['answer'] != 'Not answerable', axis=1)]
|
||||
pred_pos_data = data[data.apply(lambda k: k['pred'] != 'Not answerable', axis=1)]
|
||||
recall = sum(gt_pos_data['score'].tolist()) / len(gt_pos_data)
|
||||
precision = sum(pred_pos_data['score'].tolist()) / len(pred_pos_data)
|
||||
return 2 * recall * precision / (recall + precision)
|
||||
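# --- Illustrative sketch (added for exposition; not part of the original file) ---
# get_f1 computes a generalized F1 with "Not answerable" as the negative class: recall is
# the mean score over questions whose ground truth is answerable, precision is the mean
# score over questions the model chose to answer. The guarded demo below builds a
# hypothetical 3-row frame and assumes the pandas import (pd) used elsewhere in this module.
if __name__ == '__main__':
    _demo = pd.DataFrame({
        'answer': ['A', 'Not answerable', 'B'],
        'pred': ['A', 'C', 'Not answerable'],
        'score': [1.0, 0.0, 0.0],
    })
    print(get_f1(_demo))  # recall = 0.5, precision = 0.5 -> F1 = 0.5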
|
||||
|
||||
def MMLongBench_acc(result_file):
|
||||
data = load(result_file)
|
||||
overall_score = 0.0
|
||||
score_list = list()
|
||||
for i in range(len(data)):
|
||||
item = data.iloc[i]
|
||||
try:
|
||||
score = eval_score(item['answer'], item['pred'], item['answer_format'])
|
||||
except:
|
||||
score = 0.0
|
||||
score_list.append(score)
|
||||
overall_score += score
|
||||
|
||||
data['score'] = score_list
|
||||
dump(data, result_file)
|
||||
|
||||
data_chart = data[data.apply(lambda k: 'Chart' in eval(k['evidence_sources']), axis=1)]
|
||||
data_table = data[data.apply(lambda k: 'Table' in eval(k['evidence_sources']), axis=1)]
|
||||
data_image = data[data.apply(lambda k: 'Figure' in eval(k['evidence_sources']), axis=1)]
|
||||
data_text = data[data.apply(lambda k: 'Pure-text (Plain-text)' in eval(k['evidence_sources']), axis=1)]
|
||||
data_layout = data[data.apply(lambda k: 'Generalized-text (Layout)' in eval(k['evidence_sources']), axis=1)]
|
||||
|
||||
data_single = data[data.apply(lambda k: len(eval(k['evidence_pages'])) == 1, axis=1)]
|
||||
data_multi = data[data.apply(lambda k: len(eval(k['evidence_pages'])) > 1, axis=1)]
|
||||
data_unans = data[data.apply(lambda k: len(eval(k['evidence_pages'])) == 0, axis=1)]
|
||||
|
||||
res = dict()
|
||||
res['category'] = [
|
||||
'overall_f1', 'overall_acc', 'text', 'layout', 'table', 'chart',
|
||||
'image', 'single-page', 'multi-page', 'unanswerable'
|
||||
]
|
||||
res['num'] = [
|
||||
len(data), len(data), len(data_text), len(data_layout), len(data_table),
|
||||
len(data_chart), len(data_image), len(data_single), len(data_multi), len(data_unans)
|
||||
]
|
||||
res['avg_score'] = [
|
||||
get_f1(data),
|
||||
overall_score / len(data),
|
||||
sum(data_text['score'].tolist()) / len(data_text) if len(data_text) > 0 else 0.0,
|
||||
sum(data_layout['score'].tolist()) / len(data_layout) if len(data_layout) > 0 else 0.0,
|
||||
sum(data_table['score'].tolist()) / len(data_table) if len(data_table) > 0 else 0.0,
|
||||
sum(data_chart['score'].tolist()) / len(data_chart) if len(data_chart) > 0 else 0.0,
|
||||
sum(data_image['score'].tolist()) / len(data_image) if len(data_image) > 0 else 0.0,
|
||||
sum(data_single['score'].tolist()) / len(data_single) if len(data_single) > 0 else 0.0,
|
||||
sum(data_multi['score'].tolist()) / len(data_multi) if len(data_multi) > 0 else 0.0,
|
||||
sum(data_unans['score'].tolist()) / len(data_unans) if len(data_unans) > 0 else 0.0,
|
||||
]
|
||||
res = pd.DataFrame(res)
|
||||
return res
|
||||
|
||||
|
||||
class MMLongBench(ImageBaseDataset):
|
||||
|
||||
TYPE = 'VQA'
|
||||
|
||||
DATASET_URL = {
|
||||
'MMLongBench_DOC': 'https://opencompass.openxlab.space/utils/VLMEval/MMLongBench_DOC.tsv',
|
||||
}
|
||||
DATASET_MD5 = {
|
||||
'MMLongBench_DOC': '9b393e1f4c52718380d50586197eac9b',
|
||||
}
|
||||
|
||||
SUPPORTED_MODELS = {
|
||||
'GPT4': (1, 1),
|
||||
'GPT4V': (1, 1),
|
||||
'GPT4V_HIGH': (1, 1),
|
||||
'GPT4o': (1, 1),
|
||||
'GPT4o_HIGH': (1, 1),
|
||||
'GPT4o_MINI': (1, 1),
|
||||
'MiniCPM-Llama3-V-2_5': (1, 5),
|
||||
'InternVL-Chat-V1-5': (5, 2),
|
||||
'XComposer2_4KHD': (1, 5),
|
||||
'XComposer2d5': (1, -1),
|
||||
}
|
||||
|
||||
def __init__(self, dataset, **kwargs):
|
||||
self.model_list = list(self.SUPPORTED_MODELS.keys())
|
||||
model_name = kwargs['model']
|
||||
if not listinstr(self.model_list, model_name):
|
||||
raise AssertionError("{} doesn't support the evaluation on MMLongBench_DOC.".format(model_name))
|
||||
super(MMLongBench, self).__init__(dataset)
|
||||
|
||||
self.is_api = True if listinstr(['GPT4'], model_name) else False
|
||||
self.max_pages = 120
|
||||
concat_num, column_num = self.SUPPORTED_MODELS.get(model_name)
|
||||
self.concat_num = concat_num
|
||||
self.column_num = column_num
|
||||
|
||||
def dump_image(self, origin_line):
|
||||
os.makedirs(self.img_root, exist_ok=True)
|
||||
try:
|
||||
import fitz
|
||||
except Exception as e:
|
||||
logging.critical(f'{type(e)}: {e}')
|
||||
logging.critical('Please use `pip install pymupdf` to parse PDF files.')
|
||||
|
||||
line = origin_line.copy()
|
||||
line['image_path'] = line['image_path'][:self.max_pages]
|
||||
skip_pdf_parse = True
|
||||
for im_name in line['image_path']:
|
||||
path = osp.join(self.img_root, im_name)
|
||||
if not read_ok(path):
|
||||
skip_pdf_parse = False
|
||||
break
|
||||
|
||||
# Just for compatibility with the zipped loop below: zip(line['image'], line['image_path'])
|
||||
if skip_pdf_parse:
|
||||
line['image'] = line['image_path']
|
||||
else:
|
||||
pdf_data = base64.b64decode(line['image'])
|
||||
pdf_file = io.BytesIO(pdf_data)
|
||||
encoded_images = []
|
||||
with fitz.open(stream=pdf_file, filetype='pdf') as doc:
|
||||
doc = doc[:self.max_pages]
|
||||
for page in doc:
|
||||
image = page.get_pixmap(dpi=144)
|
||||
image_file = io.BytesIO(image.tobytes(output='png'))
|
||||
image = Image.open(image_file)
|
||||
encoded_image = encode_image_to_base64(image)
|
||||
encoded_images.append(encoded_image)
|
||||
line['image'] = encoded_images
|
||||
print('process {}'.format(line['doc_id']))
|
||||
|
||||
if 'image' in line:
|
||||
if isinstance(line['image'], list):
|
||||
tgt_path = []
|
||||
assert 'image_path' in line
|
||||
for img, im_name in zip(line['image'], line['image_path']):
|
||||
path = osp.join(self.img_root, im_name)
|
||||
if not read_ok(path):
|
||||
decode_base64_to_image_file(img, path)
|
||||
tgt_path.append(path)
|
||||
else:
|
||||
tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
|
||||
if not read_ok(tgt_path):
|
||||
decode_base64_to_image_file(line['image'], tgt_path)
|
||||
tgt_path = [tgt_path]
|
||||
else:
|
||||
assert 'image_path' in line
|
||||
tgt_path = toliststr(line['image_path'])
|
||||
|
||||
if self.concat_num > 0 and not self.is_api:
|
||||
concatenated_images = concat_images(tgt_path, max_concat=self.concat_num, column_num=self.column_num)
|
||||
|
||||
old_tgt_path = tgt_path
|
||||
assert isinstance(old_tgt_path, list)
|
||||
if self.column_num != -1:
|
||||
tgt_path = [
|
||||
'_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat{}_{}.jpg'.format(self.concat_num, i)
|
||||
for i in range(len(concatenated_images))
|
||||
]
|
||||
else:
|
||||
tgt_path = [
|
||||
'_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat_all_{}.jpg'.format(i)
|
||||
for i in range(len(concatenated_images))
|
||||
]
|
||||
|
||||
for path, concatenated_image in zip(tgt_path, concatenated_images):
|
||||
if not read_ok(path):
|
||||
decode_base64_to_image_file(encode_image_to_base64(concatenated_image), path)
|
||||
num_images, image_size = len(old_tgt_path), concatenated_image.size
|
||||
print('concat {} images to a new one with size {}. save at {}'.format(num_images, image_size, path))
|
||||
return tgt_path
|
||||
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
logger = get_logger('Evaluation')
|
||||
model = judge_kwargs['model']
|
||||
|
||||
suffix = eval_file.split('.')[-1]
|
||||
storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
|
||||
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
||||
|
||||
if osp.exists(storage):
|
||||
logger.warning(f'GPT scoring file {storage} already exists, will reuse it in MMLongBench_eval. ')
|
||||
else:
|
||||
data = load(eval_file)
|
||||
model = build_judge(max_tokens=128, **judge_kwargs)
|
||||
lt = len(data)
|
||||
lines = [data.iloc[i] for i in range(lt)]
|
||||
tups = [(model, line) for line in lines]
|
||||
indices = [line['index'] for line in lines]
|
||||
|
||||
ans = {}
|
||||
if osp.exists(tmp_file):
|
||||
ans = load(tmp_file)
|
||||
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
||||
indices = [i for i in indices if i not in ans]
|
||||
|
||||
if len(indices):
|
||||
new_results = list()
|
||||
for model, line in tqdm(tups):
|
||||
res = MMLongBench_auxeval(model, line)
|
||||
new_results.append(res)
|
||||
|
||||
log_map, res_map, pred_map = {}, {}, {}
|
||||
all_inds = [line['index'] for line in lines]
|
||||
for k, v in zip(all_inds, new_results):
|
||||
log_map[k] = v['log']
|
||||
res_map[k] = v['res']
|
||||
pred_map[k] = v['pred']
|
||||
data['res'] = [res_map[idx] for idx in data['index']]
|
||||
data['log'] = [log_map[idx] for idx in data['index']]
|
||||
data['pred'] = [pred_map[idx] for idx in data['index']]
|
||||
dump(data, storage)
|
||||
|
||||
score = MMLongBench_acc(storage)
|
||||
score_pth = storage.replace('.xlsx', '_score.csv')
|
||||
|
||||
dump(score, score_pth)
|
||||
logger.info(f'MMLongBench_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
|
||||
logger.info('Score: ')
|
||||
logger.info(score)
|
||||
446
eval_mm/vlmevalkit/vlmeval/dataset/mmmath.py
Normal file
@@ -0,0 +1,446 @@
|
||||
import re
|
||||
import json
|
||||
import sympy as sp
|
||||
import numpy as np
|
||||
from sympy import simplify, Eq, sympify, Pow, pi
|
||||
from sympy.parsing.latex import parse_latex
|
||||
import sys
|
||||
import math
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from .image_base import ImageBaseDataset
|
||||
from ..utils import track_progress_rich
|
||||
from ..smp import load, dump
|
||||
|
||||
|
||||
class AutoScoringJudge:
|
||||
def __init__(self):
|
||||
# Map of special symbols to their replacements
|
||||
self.special_signal_map = {
|
||||
"\\left": "",
|
||||
"\\right": "",
|
||||
"厘米":"",
|
||||
# "∶": ":",
|
||||
",": ",",
|
||||
"$": "",
|
||||
"(":"(",
|
||||
")":")",
|
||||
"\\infty":"oo",
|
||||
"\\colon ":":",
|
||||
# "\\approx": "=",
|
||||
# "\\simeq": "=",
|
||||
# "\\sim": "=",
|
||||
# "^\\prime": "'",
|
||||
# "^{\\prime}": "'",
|
||||
"+":"+",
|
||||
"\\, ": "",
|
||||
"\\,":"",
|
||||
"^\\circ": "",
|
||||
"^{\\circ}": "",
|
||||
# "%": "",
|
||||
}
|
||||
self.pi = parse_latex("\\pi")
|
||||
# MM-Math default precision
|
||||
self.precision = 1e-2
|
||||
|
||||
def trans_greater_sign_to_interval(self, expr:str):
|
||||
expr_tmp = expr.split("<")
|
||||
return "(" + expr_tmp[0] + ", " + expr_tmp[-1] + ")"
|
||||
|
||||
def split_by_comma(self, expr: str):
|
||||
# Splits expressions by commas outside of brackets
|
||||
in_bracket_num = 0
|
||||
splitted_expr = []
|
||||
start_idx = 0
|
||||
for i, char in enumerate(expr):
|
||||
if char in ["(", "["]:
|
||||
in_bracket_num += 1
|
||||
elif char in [")", "]"]:
|
||||
in_bracket_num -= 1
|
||||
elif char == "," and in_bracket_num == 0:
|
||||
splitted_expr.append(expr[start_idx:i].strip())
|
||||
start_idx = i + 1
|
||||
|
||||
if start_idx < len(expr):
|
||||
splitted_expr.append(expr[start_idx:].strip())
|
||||
|
||||
return splitted_expr
|
||||
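# Worked example (illustrative, not part of the original file): only top-level commas are
# split points, so split_by_comma("(1, 2), x+1") returns ["(1, 2)", "x+1"].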
|
||||
def trans_plus_minus_sign(self, expr_list: list):
|
||||
# Translates plus-minus signs into separate expressions
|
||||
new_expr_list = []
|
||||
for expr in expr_list:
|
||||
if "\\pm" in expr:
|
||||
new_expr_list.append(expr.replace("\\pm", "+"))
|
||||
new_expr_list.append(expr.replace("\\pm", "-"))
|
||||
else:
|
||||
new_expr_list.append(expr)
|
||||
|
||||
return new_expr_list
|
||||
|
||||
def judge(self, expression1, expression2, precision=1e-2):
|
||||
# Judge if two expressions are equal (expression1 is considered as the Ground Truth)
|
||||
# Normalize precision to a list so each comma-separated sub-answer can have its own tolerance
|
||||
precision = precision if isinstance(precision, list) else [precision]
|
||||
|
||||
try:
|
||||
expression1, expression2 = self.preprocess(expression1, expression2)
|
||||
except:
|
||||
return False
|
||||
if expression1 == expression2:
|
||||
# print("Exactly equal")
|
||||
return True
|
||||
|
||||
# Strip Chinese characters from the answer unless it is entirely Chinese (e.g. a Chinese yes/no answer), which is kept as-is
|
||||
expression1 = expression1 if re.fullmatch(r"[\u4e00-\u9fff]+", expression1) else re.sub(r'[\u4e00-\u9fff]+', '', expression1) # noqa: E501
|
||||
expression2 = expression2 if re.fullmatch(r'[\u4e00-\u9fff]+', expression2) else re.sub(r'[\u4e00-\u9fff]+', '', expression2) # noqa: E501
|
||||
# Check whether the expression is a double inequality (contains two '<' signs), e.g. "1<x<3", which is rewritten as the interval "(1, 3)"
|
||||
if self.is_two_greater_sign(expression1):
|
||||
expression1 = self.trans_greater_sign_to_interval(expression1)
|
||||
|
||||
if self.is_two_greater_sign(expression2):
|
||||
expression2 = self.trans_greater_sign_to_interval(expression2)
|
||||
|
||||
expression1 = self.split_by_comma(expression1)
|
||||
expression2 = self.split_by_comma(expression2)
|
||||
|
||||
temp_list1 = self.trans_plus_minus_sign(expression1)
|
||||
temp_list2 = self.trans_plus_minus_sign(expression2)
|
||||
|
||||
# Set up a list for allowed errors
|
||||
if len(precision) <= 1:
|
||||
precision = precision * len(temp_list1)
|
||||
|
||||
if len(temp_list1) != len(temp_list2):
|
||||
return False
|
||||
|
||||
# Check if elements in both lists can be paired and are equal
|
||||
idx = -1
|
||||
while len(temp_list1) != 0:
|
||||
idx = (idx + 1) % len(temp_list1)
|
||||
|
||||
item1 = temp_list1[idx]
|
||||
self.precision = precision[idx]
|
||||
|
||||
for item2 in temp_list2:
|
||||
if self.is_equal(item1, item2):
|
||||
temp_list1.remove(item1)
|
||||
temp_list2.remove(item2)
|
||||
precision.remove(self.precision)
|
||||
break
|
||||
else:
|
||||
# If no match was found, return False
|
||||
return False
|
||||
|
||||
# If all elements are matched, return True
|
||||
return True
|
||||
|
||||
def is_interval(self, expr):
|
||||
# Checks if an expression is an interval
|
||||
return expr.startswith(("(", "[")) and expr.endswith((")", "]"))
|
||||
|
||||
def is_two_greater_sign(self, expr):
|
||||
match = re.findall(r'<', expr)
|
||||
return len(match) == 2
|
||||
|
||||
def sympy_sub_pi(self, expression_sympy):
|
||||
# Replaces the symbol for pi in sympy expressions with its numerical value
|
||||
return expression_sympy.subs(self.pi, math.pi)
|
||||
|
||||
def is_equal(self, expression1, expression2):
|
||||
# expression1 is treated as the ground truth; try several notions of equality in turn
|
||||
if expression1 == expression2 and expression1 != "" and expression2 != "":
|
||||
# print("Equivalent natively")
|
||||
return True
|
||||
|
||||
# First check if both are intervals
|
||||
if self.is_interval(expression1) and self.is_interval(expression2):
|
||||
try:
|
||||
if self.interval_equal(expression1, expression2):
|
||||
# print("Interval equivalent")
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
# Then check for numerical equality
|
||||
try:
|
||||
if self.numerical_equal(expression1, expression2):
|
||||
# print("Numerically equivalent")
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
# Then check if expressions are mathematically equal
|
||||
try:
|
||||
if self.expression_equal(expression1, expression2) and not ("=" in expression1 and "=" in expression2):
|
||||
# print("Expression equivalent")
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
# Lastly, check for equation equality
|
||||
try:
|
||||
if self.equation_equal(expression1, expression2):
|
||||
# print("Equation equivalent")
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
def numerical_equal(self, expression1: str, expression2: str, include_percentage: bool = True):
|
||||
# Check if two numerical values are equal within an allowed error range
|
||||
# Includes possible percentage cases
|
||||
reference = float(expression1)
|
||||
prediction = float(expression2)
|
||||
|
||||
if include_percentage:
|
||||
gt_result = [reference / 100, reference, reference * 100]
|
||||
else:
|
||||
gt_result = [reference]
|
||||
|
||||
for item in gt_result:
|
||||
if abs(item - prediction) <= self.precision * 1.01:
|
||||
return True
|
||||
return False
|
||||
|
||||
def expression_equal(self, exp1, exp2):
|
||||
# Check if two expressions are mathematically equivalent
|
||||
# Extract expression and use sympy for equivalence checking
|
||||
def extract_expression(expression):
|
||||
if "=" in expression:
|
||||
expression = expression.split("=")[1]
|
||||
return expression.strip()
|
||||
|
||||
exp1 = extract_expression(exp1)
|
||||
exp2 = extract_expression(exp2)
|
||||
|
||||
exp_too_long = len(exp1) > 300 or len(exp2) > 300
|
||||
|
||||
expr1_sym = sympify(parse_latex(exp1))
|
||||
expr2_sym = sympify(parse_latex(exp2))
|
||||
if expr1_sym == expr2_sym:
|
||||
return True
|
||||
else:
|
||||
expr1_sym = self.sympy_sub_pi(expr1_sym)
|
||||
expr2_sym = self.sympy_sub_pi(expr2_sym)
|
||||
|
||||
if (expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol)) or \
|
||||
(not expr1_sym.has(sp.Symbol) and expr2_sym.has(sp.Symbol)):
|
||||
return False
|
||||
elif not expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol):
|
||||
try:
|
||||
if not (self.can_compute_power(expr1_sym) and self.can_compute_power(expr2_sym)):
|
||||
print("These two numbers cannot be calculated by the current computer for: "
|
||||
f"\"{str(expr1_sym)}\" and \"{str(expr2_sym)}\"")
|
||||
return False
|
||||
if exp_too_long:
|
||||
print(f'Expression {exp1} or {exp2} is too long to compute. ')
|
||||
return False
|
||||
if abs(expr1_sym.evalf() - expr2_sym.evalf()) <= self.precision * 1.01:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except:
|
||||
return False
|
||||
elif exp_too_long:
|
||||
print(f'Expression {exp1} or {exp2} is too long to compute. ')
|
||||
return False
|
||||
else:
|
||||
try:
|
||||
simplified_expr = simplify(expr1_sym - expr2_sym)
|
||||
num_value = simplified_expr.evalf()
|
||||
return abs(num_value) < 1e-3
|
||||
except:
|
||||
return False
|
||||
|
||||
def equation_equal(self, expression1, expression2):
|
||||
# Check if two equations are mathematically equivalent
|
||||
# Simplify equations and use sympy for equivalence checking
|
||||
def simplify_equation(latex_eq):
|
||||
lhs, rhs = latex_eq.split('=')
|
||||
|
||||
lhs_expr = parse_latex(lhs)
|
||||
rhs_expr = parse_latex(rhs)
|
||||
|
||||
equation = Eq(lhs_expr, rhs_expr)
|
||||
|
||||
simplified_eq = simplify(equation.lhs - equation.rhs)
|
||||
|
||||
return simplified_eq
|
||||
|
||||
expr1_sym = simplify_equation(expression1)
|
||||
expr2_sym = simplify_equation(expression2)
|
||||
|
||||
division_result_1 = simplify(expr1_sym / expr2_sym)
|
||||
division_result_2 = simplify(expr2_sym / expr1_sym)
|
||||
|
||||
if ((division_result_1.is_Integer and division_result_1 != 0) or # noqa: W504
|
||||
(division_result_2.is_Integer and division_result_2 != 0)):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def interval_equal(self, expression1, expression2):
|
||||
# Check if two intervals are mathematically equivalent
|
||||
def compare_two_interval(inter1, inter2):
|
||||
if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]:
|
||||
return False
|
||||
|
||||
inter1 = inter1.strip('[]()')
|
||||
inter2 = inter2.strip('[]()')
|
||||
|
||||
items_1 = inter1.split(',')
|
||||
items_2 = inter2.split(',')
|
||||
|
||||
for item_1, item_2 in zip(items_1, items_2):
|
||||
if not self.expression_equal(item_1, item_2):
|
||||
return False
|
||||
return True
|
||||
|
||||
interval1 = expression1
|
||||
interval2 = expression2
|
||||
|
||||
if interval1 == interval2:
|
||||
return True
|
||||
else:
|
||||
inter_list1 = interval1.split("\\cup")
|
||||
inter_list2 = interval2.split("\\cup")
|
||||
|
||||
if len(inter_list1) != len(inter_list2):
|
||||
return False
|
||||
else:
|
||||
for inter1, inter2 in zip(inter_list1, inter_list2):
|
||||
if not compare_two_interval(inter1, inter2):
|
||||
return False
|
||||
return True
|
||||
|
||||
def preprocess(self, expression1, expression2):
|
||||
# Preprocess expressions to extract and replace special symbols
|
||||
def extract_boxed_content(latex_str):
|
||||
boxed_matches = re.finditer(r'\\boxed{', latex_str)
|
||||
results = ""
|
||||
|
||||
for match in boxed_matches:
|
||||
start_index = match.end()
|
||||
end_index = start_index
|
||||
stack = 1
|
||||
|
||||
while stack > 0 and end_index < len(latex_str):
|
||||
if latex_str[end_index] == '{':
|
||||
stack += 1
|
||||
elif latex_str[end_index] == '}':
|
||||
stack -= 1
|
||||
end_index += 1
|
||||
|
||||
if stack == 0:
|
||||
content = latex_str[start_index:end_index - 1]
|
||||
results += content + ","
|
||||
else:
|
||||
raise ValueError("Mismatched braces in LaTeX string.")
|
||||
|
||||
if results == "":
|
||||
last_line_ans = latex_str.strip().split("\n")[-1]
|
||||
dollar_pattern = r"\$(.*?)\$"
|
||||
answers = re.findall(dollar_pattern, last_line_ans)
|
||||
|
||||
if answers:
|
||||
for ans in answers:
|
||||
results += ans + ","
|
||||
else:
|
||||
results = latex_str
|
||||
|
||||
return results
|
||||
|
||||
def sepcial_symbol_replace(expression):
|
||||
|
||||
expression = expression.replace("\\text{cm}^2", '').replace("\\text{cm}", "").replace("\\,cm", '').replace("\\text{ cm}", '').replace("cm", '').replace("\\text{分米}^2", '').replace("cm^{2}", '').replace("60 \\text{ cm}^2",'').replace("\\ \\text{m}", "").replace("\\text{米}","").strip() # noqa: E501
|
||||
|
||||
expression = re.sub(r"(.+)m$", r"\1", expression)
|
||||
|
||||
if "\\in " in expression:
|
||||
expression = expression.split("\\in ")[1]
|
||||
|
||||
for signal in self.special_signal_map:
|
||||
expression = expression.replace(signal, self.special_signal_map[signal])
|
||||
|
||||
expression = re.sub(r'(\\sin|\\cos|\\tan)(\d+)', r'\1((\2/180)\\pi)', expression)
|
||||
|
||||
expression = expression.strip("\n,.:;^_=+`!@#%^&*~,。")
|
||||
|
||||
pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}'
|
||||
expression = re.sub(pattern, r'\1', expression)
|
||||
|
||||
return expression
|
||||
|
||||
exp1, exp2 = extract_boxed_content(expression1), extract_boxed_content(expression2)
|
||||
|
||||
exp1, exp2 = sepcial_symbol_replace(exp1), sepcial_symbol_replace(exp2)
|
||||
|
||||
return exp1, exp2
|
||||
|
||||
def can_compute_power(self, expr):
|
||||
# Checks if a power expression can be computed
|
||||
if isinstance(expr, Pow):
|
||||
base, exp = expr.as_base_exp()
|
||||
if base.is_number and exp.is_number:
|
||||
MAX_EXP = 1000 # Adjust based on computing environment
|
||||
if abs(exp.evalf()) > MAX_EXP:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return True # Not a power expression, can compute
|
||||
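# --- Illustrative usage sketch (added for exposition; not part of the original file) ---
# AutoScoringJudge.judge() first normalizes both LaTeX answers (extracts \boxed{...}
# content, strips units and decorations), splits them on top-level commas, expands \pm,
# and then tries exact, interval, numerical, symbolic and equation equality for each
# pair. The guarded demo below uses hypothetical answers and assumes sympy's LaTeX
# parser (and its antlr4 runtime) is available.
if __name__ == '__main__':
    _judge = AutoScoringJudge()
    print(_judge.judge("\\boxed{\\frac{1}{2}}", "0.5"))  # expected True: numeric match within 1e-2
    print(_judge.judge("2x+1=5", "x=2"))                 # expected True: via the equation-equivalence check
    print(_judge.judge("3", "4"))                        # expected False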
|
||||
|
||||
class MMMath(ImageBaseDataset):
|
||||
|
||||
TYPE = 'VQA'
|
||||
|
||||
DATASET_URL = {
|
||||
'MM-Math': 'https://opencompass.openxlab.space/utils/VLMEval/MM-Math.tsv',
|
||||
}
|
||||
DATASET_MD5 = {
|
||||
'MM-Math': '1f064ed7c4e0e8926a3fa65849419ca5',
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **kwargs):
|
||||
|
||||
data = load(eval_file)
|
||||
judger = AutoScoringJudge()
|
||||
func = judger.judge
|
||||
|
||||
tups = [dict(expression1=x, expression2=y) for x, y in zip(data['answer'], data['prediction'])]
|
||||
|
||||
res = track_progress_rich(func, tups, nproc=16)
|
||||
data['hit'] = res
|
||||
dump(data, eval_file)
|
||||
|
||||
score_file = eval_file.replace('.xlsx', '_score.json')
|
||||
score = {}
|
||||
score['overall'] = np.mean(data['hit'])
|
||||
# Results by Difficulty
|
||||
difficulties = set(data['difficulty'])
|
||||
for d in difficulties:
|
||||
score[f'Difficulty-{d}'] = np.mean(data[data['difficulty'] == d]['hit'])
|
||||
|
||||
# Results by Year
|
||||
years = set(data['year'])
|
||||
for y in years:
|
||||
score[f'Year-{y}'] = np.mean(data[data['year'] == y]['hit'])
|
||||
|
||||
# Results by Knowledge-L1
|
||||
points = set(data['knowledge_l1'])
|
||||
for p in points:
|
||||
score[f'Knowledge-L1-{p}'] = np.mean(data[data['knowledge_l1'] == p]['hit'])
|
||||
|
||||
# Results by Knowledge-L2
|
||||
points = set(data['knowledge_l2'])
|
||||
for p in points:
|
||||
score[f'Knowledge-L2-{p}'] = np.mean(data[data['knowledge_l2'] == p]['hit'])
|
||||
|
||||
dump(score, score_file)
|
||||
return score
|
||||
666
eval_mm/vlmevalkit/vlmeval/dataset/mvbench.py
Normal file
@@ -0,0 +1,666 @@
|
||||
import huggingface_hub
|
||||
from huggingface_hub import snapshot_download
|
||||
from ..smp import *
|
||||
from .video_base import VideoBaseDataset
|
||||
from .utils import build_judge, DEBUG_MESSAGE
|
||||
from ..utils import track_progress_rich
|
||||
import torchvision.transforms as T
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from decord import VideoReader, cpu
|
||||
import imageio
|
||||
import cv2
|
||||
import zipfile
|
||||
import os
|
||||
import glob
|
||||
from .utils.mvbench import *
|
||||
|
||||
FAIL_MSG = 'Failed to obtain answer via API.'
|
||||
|
||||
|
||||
class MVBench(VideoBaseDataset):
|
||||
|
||||
MD5 = 'fd21d36522cdedd46d84dc46715ad832'
|
||||
SYS = """Carefully watch the video and pay attention to the cause and sequence of events, \
|
||||
the detail and movement of objects, and the action and pose of persons. \
|
||||
Based on your observations, select the best option that accurately addresses the question.
|
||||
"""
|
||||
|
||||
TYPE = 'Video-MCQ'
|
||||
|
||||
def __init__(self, dataset='MVBench', nframe=0, fps=-1):
|
||||
self.type_data_list = {
|
||||
'Action Sequence': ('action_sequence.json',
|
||||
'your_data_path/star/Charades_v1_480/', 'video', True), # has start & end
|
||||
'Action Prediction': ('action_prediction.json',
|
||||
'your_data_path/star/Charades_v1_480/', 'video', True), # has start & end
|
||||
'Action Antonym': ('action_antonym.json',
|
||||
'your_data_path/ssv2_video/', 'video', False),
|
||||
'Fine-grained Action': ('fine_grained_action.json',
|
||||
'your_data_path/Moments_in_Time_Raw/videos/', 'video', False),
|
||||
'Unexpected Action': ('unexpected_action.json',
|
||||
'your_data_path/FunQA_test/test/', 'video', False),
|
||||
'Object Existence': ('object_existence.json',
|
||||
'your_data_path/clevrer/video_validation/', 'video', False),
|
||||
'Object Interaction': ('object_interaction.json',
|
||||
'your_data_path/star/Charades_v1_480/', 'video', True), # has start & end
|
||||
'Object Shuffle': ('object_shuffle.json',
|
||||
'your_data_path/perception/videos/', 'video', False),
|
||||
'Moving Direction': ('moving_direction.json',
|
||||
'your_data_path/clevrer/video_validation/', 'video', False),
|
||||
'Action Localization': ('action_localization.json',
|
||||
'your_data_path/sta/sta_video/', 'video', True), # has start & end
|
||||
'Scene Transition': ('scene_transition.json',
|
||||
'your_data_path/scene_qa/video/', 'video', False),
|
||||
'Action Count': ('action_count.json',
|
||||
'your_data_path/perception/videos/', 'video', False),
|
||||
'Moving Count': ('moving_count.json',
|
||||
'your_data_path/clevrer/video_validation/', 'video', False),
|
||||
'Moving Attribute': ('moving_attribute.json',
|
||||
'your_data_path/clevrer/video_validation/', 'video', False),
|
||||
'State Change': ('state_change.json',
|
||||
'your_data_path/perception/videos/', 'video', False),
|
||||
'Fine-grained Pose': ('fine_grained_pose.json',
|
||||
'your_data_path/nturgbd/', 'video', False),
|
||||
'Character Order': ('character_order.json',
|
||||
'your_data_path/perception/videos/', 'video', False),
|
||||
'Egocentric Navigation': ('egocentric_navigation.json',
|
||||
'your_data_path/vlnqa/', 'video', False),
|
||||
'Episodic Reasoning': ('episodic_reasoning.json',
|
||||
'your_data_path/tvqa/frames_fps3_hq/', 'frame', True), # has start & end, read frame
|
||||
'Counterfactual Inference': ('counterfactual_inference.json',
|
||||
'your_data_path/clevrer/video_validation/', 'video', False),
|
||||
}
|
||||
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
|
||||
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return ['MVBench']
|
||||
|
||||
def prepare_dataset(self, dataset_name='MVBench', repo_id='OpenGVLab/MVBench'):
|
||||
def check_integrity(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
|
||||
if not os.path.exists(data_file):
|
||||
return False
|
||||
|
||||
if md5(data_file) != self.MD5:
|
||||
return False
|
||||
|
||||
data = load(data_file)
|
||||
for idx, item in data.iterrows():
|
||||
if not osp.exists(osp.join(pth, item['prefix'], item['video'])):
|
||||
return False
|
||||
return True
|
||||
|
||||
if modelscope_flag_set():
|
||||
repo_id = 'modelscope/MVBench'
|
||||
|
||||
cache_path = get_cache_path(repo_id, branch='main')
|
||||
if cache_path is not None and check_integrity(cache_path):
|
||||
dataset_path = cache_path
|
||||
else:
|
||||
def unzip_hf_zip(pth):
|
||||
pth = os.path.join(pth, 'video/')
|
||||
for filename in os.listdir(pth):
|
||||
if filename.endswith('.zip'):
|
||||
# Build the full path to the zip file
|
||||
zip_path = os.path.join(pth, filename)
|
||||
|
||||
# Extract the ZIP archive
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(pth)
|
||||
|
||||
def generate_tsv(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
if os.path.exists(data_file) and md5(data_file) == self.MD5:
|
||||
return
|
||||
json_data_dir = os.path.join(pth, 'json')
|
||||
self.data_list = []
|
||||
for k, v in self.type_data_list.items():
|
||||
with open(os.path.join(json_data_dir, v[0]), 'r') as f:
|
||||
json_data = json.load(f)
|
||||
for data in json_data:
|
||||
if os.path.exists(os.path.join(pth, v[1].replace('your_data_path', 'video'), data['video'])):
|
||||
self.data_list.append({
|
||||
'task_type': k,
|
||||
'prefix': v[1].replace('your_data_path', 'video'),
|
||||
'data_type': v[2],
|
||||
'bound': v[3],
|
||||
'start': data['start'] if 'start' in data.keys() else None,
|
||||
'end': data['end'] if 'end' in data.keys() else None,
|
||||
'video': data['video'],
|
||||
'question': data['question'],
|
||||
'answer': data['answer'],
|
||||
'candidates': data['candidates']
|
||||
})
|
||||
else:
|
||||
print(
|
||||
'The NTURGB-D zip file has been removed from MVBench; see '
|
||||
'https://huggingface.co/datasets/OpenGVLab/MVBench for the detailed reason.'
|
||||
)
|
||||
raise Exception(
|
||||
f"{os.path.join(v[1].replace('your_data_path', 'video'), data['video'])} does not exist"
|
||||
)
|
||||
|
||||
data_df = pd.DataFrame(self.data_list)
|
||||
data_df = data_df.assign(index=range(len(data_df)))
|
||||
data_df.to_csv(data_file, sep='\t', index=False)
|
||||
|
||||
def move_files(pth):
|
||||
src_folder = os.path.join(pth, 'video/data0613')
|
||||
if not os.path.exists(src_folder):
|
||||
return
|
||||
for subdir in os.listdir(src_folder):
|
||||
subdir_path = os.path.join(src_folder, subdir)
|
||||
if os.path.isdir(subdir_path):
|
||||
for subsubdir in os.listdir(subdir_path):
|
||||
subsubdir_path = os.path.join(subdir_path, subsubdir)
|
||||
if os.path.isdir(subsubdir_path):
|
||||
for item in os.listdir(subsubdir_path):
|
||||
item_path = os.path.join(subsubdir_path, item)
|
||||
target_folder = os.path.join(pth, 'video', subdir, subsubdir)
|
||||
if not os.path.exists(target_folder):
|
||||
os.makedirs(target_folder)
|
||||
target_path = os.path.join(target_folder, item)
|
||||
try:
|
||||
shutil.move(item_path, target_path)
|
||||
except Exception as e:
|
||||
print(f"Error moving {item_path} to {target_path}: {e}")
|
||||
|
||||
if modelscope_flag_set():
|
||||
from modelscope import dataset_snapshot_download
|
||||
dataset_path = dataset_snapshot_download(dataset_id=repo_id, revision='master')
|
||||
else:
|
||||
hf_token = os.environ.get('HUGGINGFACE_TOKEN')
|
||||
huggingface_hub.login(hf_token)
|
||||
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
|
||||
unzip_hf_zip(dataset_path)
|
||||
move_files(dataset_path)
|
||||
generate_tsv(dataset_path)
|
||||
|
||||
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
||||
|
||||
self.decord_method = {
|
||||
'video': self.read_video,
|
||||
'gif': self.read_gif,
|
||||
'frame': self.read_frame,
|
||||
}
|
||||
|
||||
self.nframe = 8
|
||||
self.frame_fps = 3
|
||||
|
||||
# transform
|
||||
self.transform = T.Compose([
|
||||
Stack(),
|
||||
ToTorchFormatTensor()
|
||||
])
|
||||
|
||||
return dict(root=dataset_path, data_file=data_file)
|
||||
|
||||
def get_index(self, bound, fps, max_frame, first_idx=0):
|
||||
if bound:
|
||||
start, end = bound[0], bound[1]
|
||||
else:
|
||||
start, end = -100000, 100000
|
||||
start_idx = max(first_idx, round(start * fps))
|
||||
end_idx = min(round(end * fps), max_frame)
|
||||
seg_size = float(end_idx - start_idx) / self.num_segments
|
||||
frame_indices = np.array([
|
||||
int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
|
||||
for idx in range(self.num_segments)
|
||||
])
|
||||
return frame_indices
|
||||
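# Worked example (illustrative, not part of the original file): with bound=(2.0, 6.0),
# fps=30, max_frame=500 and num_segments=8, start_idx=60 and end_idx=180, so seg_size=15
# and the sampled indices are 67, 82, 97, 112, 127, 142, 157, 172 -- the midpoint of each
# of the 8 equal segments between frames 60 and 180.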
|
||||
def read_video(self, video_path, bound=None):
|
||||
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
||||
max_frame = len(vr) - 1
|
||||
fps = float(vr.get_avg_fps())
|
||||
|
||||
images_group = list()
|
||||
frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
|
||||
for frame_index in frame_indices:
|
||||
img = Image.fromarray(vr[frame_index].asnumpy())
|
||||
images_group.append(img)
|
||||
torch_imgs = self.transform(images_group)
|
||||
return torch_imgs
|
||||
|
||||
def read_gif(self, video_path, bound=None, fps=25):
|
||||
gif = imageio.get_reader(video_path)
|
||||
max_frame = len(gif) - 1
|
||||
|
||||
images_group = list()
|
||||
frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
|
||||
for index, frame in enumerate(gif):
|
||||
if index in frame_indices:
|
||||
img = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
|
||||
img = Image.fromarray(img)
|
||||
images_group.append(img)
|
||||
torch_imgs = self.transform(images_group)
|
||||
return torch_imgs
|
||||
|
||||
def read_frame(self, video_path, bound=None, fps=3):
|
||||
max_frame = len(os.listdir(video_path))
|
||||
images_group = list()
|
||||
frame_indices = self.get_index(bound, fps, max_frame, first_idx=1) # frame_idx starts from 1
|
||||
for frame_index in frame_indices:
|
||||
img = Image.open(os.path.join(video_path, f'{frame_index:05d}.jpg'))
|
||||
images_group.append(img)
|
||||
torch_imgs = self.transform(images_group)
|
||||
return torch_imgs
|
||||
|
||||
def save_video_frames(self, imgs, video_name, frames):
|
||||
|
||||
frame_paths = self.frame_paths(video_name)
|
||||
flag = np.all([osp.exists(p) for p in frame_paths])
|
||||
|
||||
if not flag:
|
||||
block_size = imgs.size(0) // frames
|
||||
split_tensors = torch.split(imgs, block_size)
|
||||
to_pil = transforms.ToPILImage()
|
||||
images = [to_pil(arr) for arr in split_tensors]
|
||||
for im, pth in zip(images, frame_paths):
|
||||
if not osp.exists(pth):
|
||||
im.save(pth)
|
||||
|
||||
return frame_paths
|
||||
|
||||
def qa_template(self, data):
|
||||
question = f"Question: {data['question']}\n"
|
||||
question += 'Options:\n'
|
||||
answer = data['answer']
|
||||
answer_idx = -1
|
||||
for idx, c in enumerate(eval(data['candidates'])):
|
||||
question += f"({chr(ord('A') + idx)}) {c}\n"
|
||||
if c == answer:
|
||||
answer_idx = idx
|
||||
question = question.rstrip()
|
||||
answer = f"({chr(ord('A') + answer_idx)}) {answer}"
|
||||
return question, answer
|
||||
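# Worked example (illustrative, not part of the original file): for candidates
# ['walk', 'run'] and answer 'run', qa_template returns
# ("Question: ...\nOptions:\n(A) walk\n(B) run", "(B) run").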
|
||||
def load_into_video_and_process(self, line):
|
||||
try:
|
||||
from moviepy.editor import VideoFileClip, ImageSequenceClip
|
||||
except:
|
||||
raise ImportError(
|
||||
'MoviePy is not installed, please install it by running "pip install moviepy==1.0.3"'
|
||||
)
|
||||
video_path = os.path.join(self.data_root, line['prefix'], line['video'])
|
||||
|
||||
if line['data_type'] in ['gif'] or os.path.splitext(video_path)[1] in ['.webm']:
|
||||
processed_video_path = video_path.replace(os.path.splitext(video_path)[1], '.mp4')
|
||||
if not os.path.exists(processed_video_path):
|
||||
# use MoviePy to convert the GIF/webm file into mp4 format
|
||||
gif_clip = VideoFileClip(video_path)
|
||||
gif_clip.write_videofile(processed_video_path, codec='libx264')
|
||||
gif_clip.close()
|
||||
elif line['data_type'] in ['frame']:
|
||||
input_images = os.path.join(video_path, '*.jpg')
|
||||
processed_video_path = f'{video_path}.mp4'
|
||||
if not os.path.exists(processed_video_path):
|
||||
# use MoviePy to assemble the frame images into an mp4
|
||||
image_files = sorted(glob.glob(input_images))
|
||||
image_clip = ImageSequenceClip(image_files, fps=self.frame_fps)
|
||||
image_clip.write_videofile(processed_video_path, codec='libx264')
|
||||
image_clip.close()
|
||||
else:
|
||||
processed_video_path = video_path
|
||||
|
||||
if line['bound']:
|
||||
base_name, suffix = os.path.splitext(processed_video_path)
|
||||
output_video_path = f'{base_name}_processed{suffix}'
|
||||
if not os.path.exists(output_video_path):
|
||||
video_clip = VideoFileClip(processed_video_path)
|
||||
clip = video_clip.subclip(line['start'], min(line['end'], video_clip.duration))
|
||||
clip.write_videofile(output_video_path)
|
||||
clip.close()
|
||||
else:
|
||||
output_video_path = processed_video_path
|
||||
|
||||
return output_video_path
|
||||
|
||||
def save_video_into_images(self, line):
|
||||
bound = None
|
||||
if line['bound']:
|
||||
bound = (
|
||||
line['start'],
|
||||
line['end'],
|
||||
)
|
||||
video_path = os.path.join(self.data_root, line['prefix'], line['video'])
|
||||
decord_method = self.decord_method[line['data_type']]
|
||||
self.num_segments = self.nframe
|
||||
torch_imgs = decord_method(video_path, bound)
|
||||
img_frame_paths = self.save_video_frames(torch_imgs, line['video'], self.num_segments)
|
||||
return img_frame_paths
|
||||
|
||||
def build_prompt(self, line, video_llm):
|
||||
if self.fps > 0:
|
||||
raise ValueError('MVBench does not support the fps setting, please switch to MVBench_MP4!')
|
||||
if isinstance(line, int):
|
||||
assert line < len(self)
|
||||
line = self.data.iloc[line]
|
||||
|
||||
question, answer = self.qa_template(line)
|
||||
message = [dict(type='text', value=self.SYS, role='system')]
|
||||
message.append(dict(type='text', value=question))
|
||||
if video_llm:
|
||||
new_video_path = self.load_into_video_and_process(line)
|
||||
message.append(dict(type='video', value=new_video_path))
|
||||
else:
|
||||
img_frame_paths = self.save_video_into_images(line)
|
||||
for im in img_frame_paths:
|
||||
message.append(dict(type='image', value=im))
|
||||
message.append(dict(type='text', value='\nOnly give the best option.'))
|
||||
message.append(dict(type='text', value='Best option:(', role='assistant'))
|
||||
return message
|
||||
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
|
||||
assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
|
||||
|
||||
tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
|
||||
tgt_file = eval_file.replace('.xlsx', '_rating.json')
|
||||
score_file = eval_file.replace('.xlsx', '_score.xlsx')
|
||||
|
||||
if not osp.exists(score_file):
|
||||
model = judge_kwargs.setdefault('model', 'chatgpt-0125')
|
||||
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
||||
|
||||
if model == 'exact_matching':
|
||||
model = None
|
||||
elif gpt_key_set():
|
||||
model = build_judge(**judge_kwargs)
|
||||
if not model.working():
|
||||
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
||||
warnings.warn(DEBUG_MESSAGE)
|
||||
model = None
|
||||
else:
|
||||
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
||||
model = None
|
||||
res = {} if not osp.exists(tmp_file) else load(tmp_file)
|
||||
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
|
||||
|
||||
data = load(eval_file)
|
||||
data_un = data[~pd.isna(data['prediction'])]
|
||||
|
||||
for idx in data_un['index']:
|
||||
ans = data.loc[data['index'] == idx, 'answer'].values[0]
|
||||
pred = data.loc[data['index'] == idx, 'prediction'].values[0]
|
||||
options = eval(data.loc[data['index'] == idx, 'candidates'].values[0])
|
||||
answer_idx = -1
|
||||
for id, c in enumerate(options):
|
||||
if c == ans:
|
||||
answer_idx = id
|
||||
ans = f"({chr(ord('A') + answer_idx)}) {ans}"
|
||||
input_item = data.loc[data['index'] == idx].to_dict(orient='records')[0]
|
||||
for id, option_content in enumerate(eval(input_item['candidates'])):
|
||||
input_item[chr(ord('A') + id)] = option_content
|
||||
if option_content == input_item['answer']:
|
||||
input_item['answer'] = chr(ord('A') + id)
|
||||
|
||||
if FAIL_MSG in pred:
|
||||
data.loc[idx, 'score'] = -1
|
||||
else:
|
||||
data.loc[idx, 'score'] = int(check_ans_with_model(
|
||||
pred, ans, model,
|
||||
input_item,
|
||||
'MVBench'
|
||||
))
|
||||
|
||||
rejected = [x for x in data['score'] if x == -1]
|
||||
|
||||
print(
|
||||
f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
|
||||
f'failed to obtain the score for another {len(rejected)} questions. '
|
||||
f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
|
||||
)
|
||||
|
||||
dump(data, score_file)
|
||||
|
||||
rating = get_dimension_rating(score_file)
|
||||
dump(rating, tgt_file)
|
||||
return rating
|
||||
|
||||
|
||||
class MVBench_MP4(VideoBaseDataset):
|
||||
|
||||
MP4_MD5 = '5c8c6f8b7972c2de65a629590f7c42f5'
|
||||
SYS = """Carefully watch the video and pay attention to the cause and sequence of events, \
|
||||
the detail and movement of objects, and the action and pose of persons. \
|
||||
Based on your observations, select the best option that accurately addresses the question.
|
||||
"""
|
||||
TYPE = 'Video-MCQ'
|
||||
|
||||
def __init__(self, dataset='MVBench_MP4', nframe=0, fps=-1):
|
||||
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
|
||||
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return ['MVBench_MP4']
|
||||
|
||||
def prepare_dataset(self, dataset_name='MVBench_MP4', repo_id='OpenGVLab/MVBench'):
|
||||
def check_integrity(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
|
||||
if not os.path.exists(data_file):
|
||||
return False
|
||||
|
||||
if md5(data_file) != self.MP4_MD5:
|
||||
return False
|
||||
|
||||
data = load(data_file)
|
||||
for idx, item in data.iterrows():
|
||||
if not osp.exists(osp.join(pth, item['prefix'], item['video'])):
|
||||
return False
|
||||
return True
|
||||
|
||||
if modelscope_flag_set():
|
||||
repo_id = 'modelscope/MVBench'
|
||||
|
||||
cache_path = get_cache_path(repo_id, branch='video')
|
||||
if cache_path is not None and check_integrity(cache_path):
|
||||
dataset_path = cache_path
|
||||
else:
|
||||
def generate_tsv(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
if os.path.exists(data_file) and md5(data_file) == self.MP4_MD5:
|
||||
return
|
||||
json_data_path = os.path.join(dataset_path, 'test.json')
|
||||
json_data = load(json_data_path)
|
||||
root_data_dict = json_data['root']
|
||||
self.data_list = []
|
||||
for k, v in json_data['meta'].items():
|
||||
for item in v:
|
||||
self.data_list.append({
|
||||
'task_type': k,
|
||||
'prefix': root_data_dict[k],
|
||||
'video': item['video'],
|
||||
'question': item['question'],
|
||||
'answer': item['answer'],
|
||||
'candidates': item['candidates']
|
||||
})
|
||||
data_df = pd.DataFrame(self.data_list)
|
||||
data_df = data_df.assign(index=range(len(data_df)))
|
||||
data_df.to_csv(data_file, sep='\t', index=False)
|
||||
|
||||
if modelscope_flag_set():
|
||||
from modelscope import dataset_snapshot_download
|
||||
dataset_path = dataset_snapshot_download(dataset_id=repo_id, revision='video')
|
||||
else:
|
||||
hf_token = os.environ.get('HUGGINGFACE_TOKEN')
|
||||
huggingface_hub.login(hf_token)
|
||||
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset', revision='video')
|
||||
generate_tsv(dataset_path)
|
||||
|
||||
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
||||
|
||||
# transform
|
||||
self.transform = T.Compose([
|
||||
Stack(),
|
||||
ToTorchFormatTensor()
|
||||
])
|
||||
|
||||
return dict(root=dataset_path, data_file=data_file)
|
||||
|
||||
def qa_template(self, data):
|
||||
question = f"Question: {data['question']}\n"
|
||||
question += 'Options:\n'
|
||||
answer = data['answer']
|
||||
answer_idx = -1
|
||||
for idx, c in enumerate(eval(data['candidates'])):
|
||||
question += f"({chr(ord('A') + idx)}) {c}\n"
|
||||
if c == answer:
|
||||
answer_idx = idx
|
||||
question = question.rstrip()
|
||||
answer = f"({chr(ord('A') + answer_idx)}) {answer}"
|
||||
return question, answer
|
||||
|
||||
def get_index_by_frame(self, max_frame):
|
||||
seg_size = float(max_frame) / self.num_segments
|
||||
frame_indices = np.array([
|
||||
int((seg_size / 2) + np.round(seg_size * idx))
|
||||
for idx in range(self.num_segments)
|
||||
])
|
||||
return frame_indices
|
||||
|
||||
def get_index_by_fps(self, vid, fps):
|
||||
total_frames = len(vid)
|
||||
video_fps = vid.get_avg_fps()
|
||||
total_duration = total_frames / video_fps
|
||||
required_frames = int(total_duration * fps)
|
||||
step_size = video_fps / fps
|
||||
frame_indices = np.array([int(i * step_size) for i in range(required_frames)])
|
||||
self.num_segments = len(frame_indices)
|
||||
return frame_indices
|
||||
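# Worked example (illustrative, not part of the original file): a 300-frame clip at
# 30 fps sampled with fps=1 gives total_duration=10 s, required_frames=10 and
# step_size=30, so the indices are 0, 30, 60, ..., 270 and num_segments is set to 10.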
|
||||
def read_video(self, video_path):
|
||||
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
||||
max_frame = len(vr) - 1
|
||||
|
||||
images_group = list()
|
||||
if self.fps < 0:
|
||||
frame_indices = self.get_index_by_frame(max_frame)
|
||||
else:
|
||||
frame_indices = self.get_index_by_fps(vr, self.fps)
|
||||
|
||||
for frame_index in frame_indices:
|
||||
img = Image.fromarray(vr[frame_index].asnumpy())
|
||||
images_group.append(img)
|
||||
torch_imgs = self.transform(images_group)
|
||||
return torch_imgs
|
||||
|
||||
def save_video_frames(self, imgs, video_name, frames):
|
||||
if self.fps > 0:
|
||||
frame_paths = self.frame_paths_fps(video_name, frames)
|
||||
else:
|
||||
frame_paths = self.frame_paths(video_name)
|
||||
flag = np.all([osp.exists(p) for p in frame_paths])
|
||||
|
||||
if not flag:
|
||||
block_size = imgs.size(0) // frames
|
||||
split_tensors = torch.split(imgs, block_size)
|
||||
to_pil = transforms.ToPILImage()
|
||||
images = [to_pil(arr) for arr in split_tensors]
|
||||
for im, pth in zip(images, frame_paths):
|
||||
if not osp.exists(pth):
|
||||
im.save(pth)
|
||||
|
||||
return frame_paths
|
||||
|
||||
def save_video_into_images(self, line):
|
||||
video_path = os.path.join(self.data_root, line['prefix'], line['video'])
|
||||
if self.fps <= 0:
|
||||
self.num_segments = self.nframe
|
||||
else:
|
||||
self.num_segments = 0
|
||||
torch_imgs = self.read_video(video_path)
|
||||
img_frame_paths = self.save_video_frames(torch_imgs, line['video'], self.num_segments)
|
||||
return img_frame_paths
|
||||
|
||||
def build_prompt(self, line, video_llm):
|
||||
if isinstance(line, int):
|
||||
assert line < len(self)
|
||||
line = self.data.iloc[line]
|
||||
|
||||
question, answer = self.qa_template(line)
|
||||
message = [dict(type='text', value=self.SYS, role='system')]
|
||||
message.append(dict(type='text', value=question))
|
||||
video_path = os.path.join(self.data_root, line['prefix'], line['video'])
|
||||
if video_llm:
|
||||
message.append(dict(type='video', value=video_path))
|
||||
else:
|
||||
img_frame_paths = self.save_video_into_images(line)
|
||||
for im in img_frame_paths:
|
||||
message.append(dict(type='image', value=im))
|
||||
message.append(dict(type='text', value='\nOnly give the best option.'))
|
||||
message.append(dict(type='text', value='Best option:(', role='assistant'))
|
||||
return message
|
||||
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
|
||||
assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
|
||||
|
||||
tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
|
||||
tgt_file = eval_file.replace('.xlsx', '_rating.json')
|
||||
score_file = eval_file.replace('.xlsx', '_score.xlsx')
|
||||
|
||||
if not osp.exists(score_file):
|
||||
model = judge_kwargs.setdefault('model', 'chatgpt-0125')
|
||||
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
||||
|
||||
if model == 'exact_matching':
|
||||
model = None
|
||||
elif gpt_key_set():
|
||||
model = build_judge(**judge_kwargs)
|
||||
if not model.working():
|
||||
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
||||
warnings.warn(DEBUG_MESSAGE)
|
||||
model = None
|
||||
else:
|
||||
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
||||
model = None
|
||||
res = {} if not osp.exists(tmp_file) else load(tmp_file)
|
||||
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
|
||||
|
||||
data = load(eval_file)
|
||||
data_un = data[~pd.isna(data['prediction'])]
|
||||
|
||||
for idx in data_un['index']:
|
||||
ans = data.loc[data['index'] == idx, 'answer'].values[0]
|
||||
pred = data.loc[data['index'] == idx, 'prediction'].values[0]
|
||||
options = eval(data.loc[data['index'] == idx, 'candidates'].values[0])
|
||||
answer_idx = -1
|
||||
for id, c in enumerate(options):
|
||||
if c == ans:
|
||||
answer_idx = id
|
||||
ans = f"({chr(ord('A') + answer_idx)}) {ans}"
|
||||
input_item = data.loc[data['index'] == idx].to_dict(orient='records')[0]
|
||||
for id, option_content in enumerate(eval(input_item['candidates'])):
|
||||
input_item[chr(ord('A') + id)] = option_content
|
||||
if option_content == input_item['answer']:
|
||||
input_item['answer'] = chr(ord('A') + id)
|
||||
|
||||
if FAIL_MSG in pred:
|
||||
data.loc[idx, 'score'] = -1
|
||||
else:
|
||||
data.loc[idx, 'score'] = int(check_ans_with_model(
|
||||
pred, ans, model,
|
||||
input_item,
|
||||
'MVBench_MP4'
|
||||
))
|
||||
|
||||
rejected = [x for x in data['score'] if x == -1]
|
||||
|
||||
print(
|
||||
f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
|
||||
f'failed to obtain the score for another {len(rejected)} questions. '
|
||||
f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
|
||||
)
|
||||
|
||||
dump(data, score_file)
|
||||
|
||||
rating = get_dimension_rating(score_file)
|
||||
dump(rating, tgt_file)
|
||||
return rating
|
||||
189
eval_mm/vlmevalkit/vlmeval/dataset/slidevqa.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import re
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
from vlmeval.dataset.utils.judge_util import build_judge
|
||||
from vlmeval.smp import *
|
||||
from .image_base import ImageBaseDataset
|
||||
from .mmlongbench import concat_images, MMLongBench_auxeval, anls_compute
|
||||
|
||||
|
||||
FAIL_MSG = 'Failed to obtain answer via API.'
|
||||
|
||||
|
||||
def get_f1(gt, pred):
|
||||
gt_bow, pred_bow = gt.strip().split(), pred.strip().split()
|
||||
if not gt_bow or not pred_bow:
|
||||
return 0.0
|
||||
|
||||
recall = len([pred_e for pred_e in pred_bow if pred_e in gt_bow]) / len(gt_bow)
|
||||
precision = len([pred_e for pred_e in pred_bow if pred_e in gt_bow]) / len(pred_bow)
|
||||
f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 1e-4 else 0.0
|
||||
return f1
|
||||
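# --- Illustrative usage sketch (added for exposition; not part of the original file) ---
# This get_f1 is a token-level bag-of-words F1: recall counts predicted tokens that appear
# in the ground truth relative to the ground-truth length, precision relative to the
# prediction length. The guarded demo below uses a hypothetical answer pair.
if __name__ == '__main__':
    print(get_f1('the eiffel tower', 'eiffel tower paris'))  # recall 2/3, precision 2/3 -> 0.667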
|
||||
|
||||
def SlideVQA_acc(result_file):
|
||||
data = load(result_file)
|
||||
anls_list, em_list, f1_list = list(), list(), list()
|
||||
for i in range(len(data)):
|
||||
item = data.iloc[i]
|
||||
if isinstance(item['answer'], float) and math.isnan(item['answer']):
|
||||
item['answer'] = 'Not answerable'
|
||||
|
||||
item['answer'] = re.sub('\n', '', item['answer']).lower()
|
||||
item['pred'] = str(item['pred']).lower()
|
||||
anls_score = anls_compute(item['answer'], item['pred'])
|
||||
em_score = (item['answer'].strip() == item['pred'].strip())
|
||||
f1_score = get_f1(item['answer'], item['pred'])
|
||||
anls_list.append(anls_score)
|
||||
em_list.append(em_score)
|
||||
f1_list.append(f1_score)
|
||||
print('---------------------')
|
||||
print(item['answer'], item['pred'], anls_score, em_score, f1_score)
|
||||
|
||||
data['anls'] = anls_list
|
||||
data['em'] = em_list
|
||||
data['f1'] = f1_list
|
||||
dump(data, result_file)
|
||||
|
||||
res = dict()
|
||||
res['category'], res['num'] = ['anls', 'EM', 'F1'], [len(data), len(data), len(data)]
|
||||
res['avg'] = [sum(anls_list) / len(data), sum(em_list) / len(data), sum(f1_list) / len(data)]
|
||||
res = pd.DataFrame(res)
|
||||
return res
|
||||
|
||||
|
||||
class SlideVQA(ImageBaseDataset):
|
||||
|
||||
TYPE = 'VQA'
|
||||
|
||||
DATASET_URL = {
|
||||
'SLIDEVQA_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/SLIDEVQA_MINI.tsv',
|
||||
'SLIDEVQA': 'https://opencompass.openxlab.space/utils/VLMEval/SLIDEVQA.tsv',
|
||||
}
|
||||
DATASET_MD5 = {
|
||||
'SLIDEVQA_MINI': '6d9a8d8814fa5b7669deb2af3a3208eb',
|
||||
'SLIDEVQA': '5e822c2f800e94c1e23badfd478326b6',
|
||||
}
|
||||
|
||||
SUPPORTED_MODELS = {
|
||||
'GPT4': (1, 1),
|
||||
'GPT4V': (1, 1),
|
||||
'GPT4V_HIGH': (1, 1),
|
||||
'GPT4o': (1, 1),
|
||||
'GPT4o_HIGH': (1, 1),
|
||||
'GPT4o_MINI': (1, 1),
|
||||
'XComposer2d5': (1, -1),
|
||||
'XComposer2_4KHD': (1, -1),
|
||||
'MiniCPM-Llama3-V-2_5': (1, 5),
|
||||
'InternVL-Chat-V1-5': (5, 2),
|
||||
}
|
||||
|
||||
def __init__(self, dataset, **kwargs):
|
||||
self.model_list = list(self.SUPPORTED_MODELS.keys())
|
||||
model_name = kwargs['model']
|
||||
if not listinstr(self.model_list, model_name):
|
||||
raise AssertionError("{} doesn't support the evaluation on SlideVQA.".format(model_name))
|
||||
super(SlideVQA, self).__init__(dataset)
|
||||
|
||||
self.is_api = True if listinstr(['GPT4'], model_name) else False
|
||||
self.max_pages = 120
|
||||
concat_num, column_num = self.SUPPORTED_MODELS.get(model_name)
|
||||
self.concat_num = concat_num
|
||||
self.column_num = column_num
|
||||
|
||||
def dump_image(self, origin_line):
|
||||
os.makedirs(self.img_root, exist_ok=True)
|
||||
|
||||
line = origin_line.copy()
|
||||
if not isinstance(line['image_path'], List):
|
||||
line['image_path'] = [line['image_path']]
|
||||
line['image_path'] = line['image_path'][:self.max_pages]
|
||||
|
||||
if 'image' in line:
|
||||
if isinstance(line['image'], list):
|
||||
tgt_path = []
|
||||
assert 'image_path' in line
|
||||
for img, im_name in zip(line['image'], line['image_path']):
|
||||
path = osp.join(self.img_root, im_name)
|
||||
if not read_ok(path):
|
||||
decode_base64_to_image_file(img, path)
|
||||
tgt_path.append(path)
|
||||
else:
|
||||
tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
|
||||
if not read_ok(tgt_path):
|
||||
decode_base64_to_image_file(line['image'], tgt_path)
|
||||
tgt_path = [tgt_path]
|
||||
else:
|
||||
assert 'image_path' in line
|
||||
tgt_path = toliststr(line['image_path'])
|
||||
|
||||
if self.concat_num > 0 and not self.is_api:
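# For local (non-API) models, tile the slide pages into at most `concat_num` composite images with
# `column_num` columns each; column_num == -1 appears to merge all pages into a single
# '_concat_all.jpg' (see the branch below).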
|
||||
concatenated_images = concat_images(tgt_path, max_concat=self.concat_num, column_num=self.column_num)
|
||||
|
||||
old_tgt_path = tgt_path
|
||||
assert isinstance(old_tgt_path, list)
|
||||
if self.column_num != -1:
|
||||
tgt_path = [
|
||||
'_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat{}_{}.jpg'.format(self.concat_num, i)
|
||||
for i in range(len(concatenated_images))
|
||||
]
|
||||
else:
|
||||
tgt_path = ['_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat_all.jpg']
|
||||
|
||||
for path, concatenated_image in zip(tgt_path, concatenated_images):
|
||||
if not read_ok(path):
|
||||
decode_base64_to_image_file(encode_image_to_base64(concatenated_image), path)
|
||||
num_images, image_size = len(old_tgt_path), concatenated_image.size
|
||||
print('Concatenated {} images into a new image of size {}, saved at {}'.format(num_images, image_size, path))
|
||||
return tgt_path
|
||||
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
logger = get_logger('Evaluation')
|
||||
model = judge_kwargs['model']
|
||||
|
||||
suffix = eval_file.split('.')[-1]
|
||||
storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
|
||||
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
||||
|
||||
if osp.exists(storage):
|
||||
logger.warning(f'GPT scoring file {storage} already exists, will reuse it in SlideVQA_eval. ')
|
||||
else:
|
||||
data = load(eval_file)
|
||||
model = build_judge(max_tokens=128, **judge_kwargs)
|
||||
lt = len(data)
|
||||
lines = [data.iloc[i] for i in range(lt)]
|
||||
tups = [(model, line) for line in lines]
|
||||
indices = [line['index'] for line in lines]
|
||||
|
||||
ans = {}
|
||||
if osp.exists(tmp_file):
|
||||
ans = load(tmp_file)
|
||||
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
||||
indices = [i for i in indices if i not in ans]
|
||||
|
||||
if len(indices):
|
||||
new_results = list()
|
||||
for model, line in tqdm(tups):
|
||||
res = MMLongBench_auxeval(model, line)
|
||||
new_results.append(res)
|
||||
|
||||
log_map, res_map, pred_map = {}, {}, {}
|
||||
all_inds = [line['index'] for line in lines]
|
||||
for k, v in zip(all_inds, new_results):
|
||||
log_map[k] = v['log']
|
||||
res_map[k] = v['res']
|
||||
pred_map[k] = v['pred']
|
||||
data['res'] = [res_map[idx] for idx in data['index']]
|
||||
data['log'] = [log_map[idx] for idx in data['index']]
|
||||
data['pred'] = [pred_map[idx] for idx in data['index']]
|
||||
dump(data, storage)
|
||||
|
||||
score = SlideVQA_acc(storage)
|
||||
score_pth = storage.replace('.xlsx', '_score.csv')
|
||||
|
||||
dump(score, score_pth)
|
||||
logger.info(f'SlideVQA successfully finished evaluating {eval_file}, results saved in {score_pth}')
|
||||
logger.info('Score: ')
|
||||
logger.info(score)
|
||||
639
eval_mm/vlmevalkit/vlmeval/dataset/tempcompass.py
Normal file
@@ -0,0 +1,639 @@
|
||||
import huggingface_hub
|
||||
from huggingface_hub import snapshot_download
|
||||
from ..smp import *
|
||||
from .video_concat_dataset import ConcatVideoDataset
|
||||
from .video_base import VideoBaseDataset
|
||||
from .utils import build_judge, DEBUG_MESSAGE
|
||||
from ..utils import track_progress_rich
|
||||
import torchvision.transforms as T
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from decord import VideoReader, cpu
|
||||
from .utils.tempcompass import *
|
||||
|
||||
|
||||
FAIL_MSG = 'Failed to obtain answer via API.'
|
||||
|
||||
|
||||
class TempCompass(ConcatVideoDataset):
|
||||
def __init__(self, dataset='TempCompass', nframe=0, fps=-1):
|
||||
self.DATASET_SETS[dataset] = ['TempCompass_MCQ', 'TempCompass_Captioning', 'TempCompass_YorN']
|
||||
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
|
||||
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return ['TempCompass']
|
||||
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
result = super().evaluate(eval_file=eval_file, **judge_kwargs)
|
||||
suffix = eval_file.split('.')[-1]
|
||||
result = result.reset_index().rename(columns={'index': 'dim.task_type'})
|
||||
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
||||
avg_dict = {}
|
||||
for idx, item in result.iterrows():
|
||||
dim, task_type = item['dim.task_type'].split('. ')
|
||||
if dim not in avg_dict:
|
||||
avg_dict[dim] = {'success': 0.0, 'overall': 0.0}
|
||||
if task_type not in avg_dict:
|
||||
avg_dict[task_type] = {'success': 0.0, 'overall': 0.0}
|
||||
if 'overall' not in avg_dict:
|
||||
avg_dict['overall'] = {'success': 0.0, 'overall': 0.0}
|
||||
avg_dict[dim]['success'] += item['success']
|
||||
avg_dict[dim]['overall'] += item['overall']
|
||||
avg_dict[task_type]['success'] += item['success']
|
||||
avg_dict[task_type]['overall'] += item['overall']
|
||||
avg_dict['overall']['success'] += item['success']
|
||||
avg_dict['overall']['overall'] += item['overall']
|
||||
result.loc[idx, 'acc'] = round(item['success'] / item['overall'] * 100, 2)
|
||||
for key, value in avg_dict.items():
|
||||
# append the aggregated summary row via .loc
|
||||
result.loc[len(result)] = {
|
||||
'dim.task_type': key,
|
||||
'success': value['success'],
|
||||
'overall': value['overall'],
|
||||
'acc': round(value['success'] / value['overall'] * 100, 2)
|
||||
}
|
||||
dump(result, score_file)
|
||||
return result
|
||||
|
||||
|
||||
class TempCompass_MCQ(VideoBaseDataset):
|
||||
|
||||
MD5 = '7efbb9e6d9dabacd22daf274852691dd'
|
||||
TYPE = 'Video-MCQ'
|
||||
|
||||
def __init__(self, dataset='TempCompass_MCQ', nframe=0, fps=-1):
|
||||
self.type_data_list = {
|
||||
'multi-choice': ('multi-choice.json', './videos', '.mp4'),
|
||||
'caption_matching': ('caption_matching.json', './videos', '.mp4'),
|
||||
}
|
||||
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
|
||||
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return ['TempCompass_MCQ']
|
||||
|
||||
def prepare_dataset(self, dataset_name='TempCompass_MCQ', repo_id='lmms-lab/TempCompass'):
|
||||
def check_integrity(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
|
||||
if not osp.exists(data_file):
|
||||
return False
|
||||
|
||||
if md5(data_file) != self.MD5:
|
||||
return False
|
||||
|
||||
data = load(data_file)
|
||||
for idx, item in data.iterrows():
|
||||
if not osp.exists(osp.join(pth, item['prefix'], item['video'] + item['suffix'])):
|
||||
return False
|
||||
return True
|
||||
|
||||
cache_path = get_cache_path(repo_id)
|
||||
if cache_path is not None and check_integrity(cache_path):
|
||||
dataset_path = cache_path
|
||||
else:
|
||||
def read_parquet(pth):
|
||||
import pandas as pd
|
||||
for task_name in self.type_data_list.keys():
|
||||
if not osp.exists(osp.join(pth, f'{task_name}.json')):
|
||||
data = pd.read_parquet(osp.join(pth, task_name, 'test-00000-of-00001.parquet'))
|
||||
data.to_json(osp.join(pth, f'{task_name}.json'), orient='records', lines=False)
|
||||
|
||||
def unzip_videos(pth):
|
||||
import zipfile
|
||||
if not osp.exists(osp.join(pth, 'videos')):
|
||||
zip_file = osp.join(pth, 'tempcompass_videos.zip')
|
||||
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
|
||||
zip_ref.extractall(pth)
|
||||
|
||||
def generate_tsv(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
if osp.exists(data_file) and md5(data_file) == self.MD5:
|
||||
return
|
||||
self.data_list = []
|
||||
for k, v in self.type_data_list.items():
|
||||
with open(osp.join(pth, v[0]), 'r') as f:
|
||||
json_data = json.load(f)
|
||||
for data in json_data:
|
||||
self.data_list.append({
|
||||
'task_type': k,
|
||||
'prefix': v[1],
|
||||
'suffix': v[2],
|
||||
'video': data['video_id'],
|
||||
'question': data['question'].split('\n')[0],
|
||||
'answer': data['answer'],
|
||||
'dim': data['dim'],
|
||||
'candidates': data['question'].split('\n')[1:],
|
||||
})
|
||||
|
||||
data_df = pd.DataFrame(self.data_list)
|
||||
data_df = data_df.assign(index=range(len(data_df)))
|
||||
data_df.to_csv(data_file, sep='\t', index=False)
|
||||
|
||||
if modelscope_flag_set():
|
||||
from modelscope import dataset_snapshot_download
|
||||
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
||||
else:
|
||||
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
|
||||
read_parquet(dataset_path)
|
||||
unzip_videos(dataset_path)
|
||||
generate_tsv(dataset_path)
|
||||
|
||||
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
||||
return dict(root=dataset_path, data_file=data_file)
|
||||
|
||||
def qa_template(self, data):
|
||||
question = data['question'] + '\n' + '\n'.join(eval(data['candidates']))
|
||||
answer = data['answer']
|
||||
return question, answer
|
||||
|
||||
def save_video_frames(self, line):
|
||||
vid_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
|
||||
vid = decord.VideoReader(vid_path)
|
||||
video_info = {
|
||||
'fps': vid.get_avg_fps(),
|
||||
'n_frames': len(vid),
|
||||
}
|
||||
if self.nframe > 0 and self.fps < 0:
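# uniform sampling: pick nframe frames at evenly spaced interior positions (endpoints excluded)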
|
||||
step_size = len(vid) / (self.nframe + 1)
|
||||
indices = [int(i * step_size) for i in range(1, self.nframe + 1)]
|
||||
frame_paths = self.frame_paths(line['video'])
|
||||
elif self.fps > 0:
|
||||
# not constrained by num_frames, get frames by fps
|
||||
total_duration = video_info['n_frames'] / video_info['fps']
|
||||
required_frames = int(total_duration * self.fps)
|
||||
step_size = video_info['fps'] / self.fps
|
||||
indices = [int(i * step_size) for i in range(required_frames)]
|
||||
frame_paths = self.frame_paths_fps(line['video'], len(indices))
|
||||
|
||||
flag = np.all([osp.exists(p) for p in frame_paths])
|
||||
|
||||
if not flag:
|
||||
images = [vid[i].asnumpy() for i in indices]
|
||||
images = [Image.fromarray(arr) for arr in images]
|
||||
for im, pth in zip(images, frame_paths):
|
||||
if not osp.exists(pth):
|
||||
im.save(pth)
|
||||
|
||||
return frame_paths
|
||||
|
||||
def save_video_into_images(self, line):
|
||||
frame_paths = self.save_video_frames(line)
|
||||
return frame_paths
|
||||
|
||||
def build_prompt(self, line, video_llm):
|
||||
if isinstance(line, int):
|
||||
assert line < len(self)
|
||||
line = self.data.iloc[line]
|
||||
|
||||
question, answer = self.qa_template(line)
|
||||
message = []
|
||||
message.append(dict(type='text', value=question))
|
||||
video_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
|
||||
if video_llm:
|
||||
message.append(dict(type='video', value=video_path))
|
||||
else:
|
||||
img_frame_paths = self.save_video_into_images(line)
|
||||
for im in img_frame_paths:
|
||||
message.append(dict(type='image', value=im))
|
||||
message.append(dict(type='text', value='\nPlease directly give the best option:'))
|
||||
return message
|
||||
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
model = judge_kwargs.get('model', 'exact_matching')
|
||||
assert model in ['chatgpt-1106', 'exact_matching']
|
||||
judge_kwargs.update({
|
||||
"max_tokens": 128,
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"presence_penalty": 1,
|
||||
})
|
||||
|
||||
suffix = eval_file.split('.')[-1]
|
||||
score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.xlsx')
|
||||
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
||||
nproc = judge_kwargs.pop('nproc', 4)
|
||||
|
||||
if not osp.exists(score_file):
|
||||
data = load(eval_file)
|
||||
if model != 'exact_matching':
|
||||
model = build_judge(system_prompt=sys_prompt, **judge_kwargs)
|
||||
else:
|
||||
model = None
|
||||
|
||||
lt = len(data)
|
||||
lines = [data.iloc[i] for i in range(lt)]
|
||||
tups = [(model, line) for line in lines]
|
||||
indices = [line['index'] for line in lines]
|
||||
|
||||
ans = {}
|
||||
if osp.exists(tmp_file):
|
||||
ans = load(tmp_file)
|
||||
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
||||
indices = [i for i in indices if i not in ans]
|
||||
|
||||
if len(indices):
|
||||
_ = track_progress_rich(
|
||||
evaluate_tempcompass_mcq,
|
||||
tups,
|
||||
nproc=nproc,
|
||||
chunksize=nproc,
|
||||
keys=indices,
|
||||
save=tmp_file,
|
||||
)
|
||||
ans = load(tmp_file)
|
||||
for idx, item in data.iterrows():
|
||||
data.loc[idx, 'score'] = ans[idx]['rating']
|
||||
dump(data, score_file)
|
||||
|
||||
rating = get_dimension_rating(score_file)
|
||||
return rating
|
||||
|
||||
|
||||
class TempCompass_Captioning(VideoBaseDataset):
|
||||
|
||||
MD5 = '35be9bf2581ea7767f02e9a8f37ae1ab'
|
||||
TYPE = 'Video-VQA'
|
||||
|
||||
def __init__(self, dataset='TempCompass_Captioning', nframe=0, fps=-1):
|
||||
self.type_data_list = {
|
||||
'captioning': ('captioning.json', './videos', '.mp4'),
|
||||
}
|
||||
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
|
||||
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return ['TempCompass_Captioning']
|
||||
|
||||
def prepare_dataset(self, dataset_name='TempCompass_Captioning', repo_id='lmms-lab/TempCompass'):
|
||||
def check_integrity(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
|
||||
if not osp.exists(data_file):
|
||||
return False
|
||||
|
||||
if md5(data_file) != self.MD5:
|
||||
return False
|
||||
|
||||
data = load(data_file)
|
||||
for idx, item in data.iterrows():
|
||||
if not osp.exists(osp.join(pth, item['prefix'], item['video'] + item['suffix'])):
|
||||
return False
|
||||
return True
|
||||
|
||||
cache_path = get_cache_path(repo_id)
|
||||
if cache_path is not None and check_integrity(cache_path):
|
||||
dataset_path = cache_path
|
||||
else:
|
||||
def read_parquet(pth):
|
||||
import pandas as pd
|
||||
for task_name in self.type_data_list.keys():
|
||||
if not osp.exists(osp.join(pth, f'{task_name}.json')):
|
||||
data = pd.read_parquet(osp.join(pth, task_name, 'test-00000-of-00001.parquet'))
|
||||
data.to_json(osp.join(pth, f'{task_name}.json'), orient='records', lines=False)
|
||||
|
||||
def unzip_videos(pth):
|
||||
import zipfile
|
||||
if not osp.exists(osp.join(pth, 'videos')):
|
||||
zip_file = osp.join(pth, 'tempcompass_videos.zip')
|
||||
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
|
||||
zip_ref.extractall(pth)
|
||||
|
||||
def generate_tsv(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
if osp.exists(data_file) and md5(data_file) == self.MD5:
|
||||
return
|
||||
self.data_list = []
|
||||
for k, v in self.type_data_list.items():
|
||||
with open(osp.join(pth, v[0]), 'r') as f:
|
||||
json_data = json.load(f)
|
||||
for data in json_data:
|
||||
self.data_list.append({
|
||||
'task_type': k,
|
||||
'prefix': v[1],
|
||||
'suffix': v[2],
|
||||
'video': data['video_id'],
|
||||
'question': data['question'],
|
||||
'answer': data['answer'],
|
||||
'dim': data['dim'],
|
||||
'mc_question': data['mc_question'],
|
||||
'mc_answer': data['mc_answer'],
|
||||
})
|
||||
|
||||
data_df = pd.DataFrame(self.data_list)
|
||||
data_df = data_df.assign(index=range(len(data_df)))
|
||||
data_df.to_csv(data_file, sep='\t', index=False)
|
||||
|
||||
if modelscope_flag_set():
|
||||
from modelscope import dataset_snapshot_download
|
||||
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
||||
else:
|
||||
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
|
||||
read_parquet(dataset_path)
|
||||
unzip_videos(dataset_path)
|
||||
generate_tsv(dataset_path)
|
||||
|
||||
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
||||
return dict(root=dataset_path, data_file=data_file)
|
||||
|
||||
def qa_template(self, data):
|
||||
question = data['question']
|
||||
answer = data['answer']
|
||||
return question, answer
|
||||
|
||||
def save_video_frames(self, line):
|
||||
vid_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
|
||||
vid = decord.VideoReader(vid_path)
|
||||
video_info = {
|
||||
'fps': vid.get_avg_fps(),
|
||||
'n_frames': len(vid),
|
||||
}
|
||||
if self.nframe > 0 and self.fps < 0:
|
||||
step_size = len(vid) / (self.nframe + 1)
|
||||
indices = [int(i * step_size) for i in range(1, self.nframe + 1)]
|
||||
frame_paths = self.frame_paths(line['video'])
|
||||
elif self.fps > 0:
|
||||
# not constrained by num_frames, get frames by fps
|
||||
total_duration = video_info['n_frames'] / video_info['fps']
|
||||
required_frames = int(total_duration * self.fps)
|
||||
step_size = video_info['fps'] / self.fps
|
||||
indices = [int(i * step_size) for i in range(required_frames)]
|
||||
frame_paths = self.frame_paths_fps(line['video'], len(indices))
|
||||
|
||||
flag = np.all([osp.exists(p) for p in frame_paths])
|
||||
|
||||
if not flag:
|
||||
images = [vid[i].asnumpy() for i in indices]
|
||||
images = [Image.fromarray(arr) for arr in images]
|
||||
for im, pth in zip(images, frame_paths):
|
||||
if not osp.exists(pth):
|
||||
im.save(pth)
|
||||
|
||||
return frame_paths
|
||||
|
||||
def save_video_into_images(self, line):
|
||||
frame_paths = self.save_video_frames(line)
|
||||
return frame_paths
|
||||
|
||||
def build_prompt(self, line, video_llm):
|
||||
if isinstance(line, int):
|
||||
assert line < len(self)
|
||||
line = self.data.iloc[line]
|
||||
|
||||
question, answer = self.qa_template(line)
|
||||
message = []
|
||||
message.append(dict(type='text', value=question))
|
||||
video_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
|
||||
if video_llm:
|
||||
message.append(dict(type='video', value=video_path))
|
||||
else:
|
||||
img_frame_paths = self.save_video_into_images(line)
|
||||
for im in img_frame_paths:
|
||||
message.append(dict(type='image', value=im))
|
||||
return message
|
||||
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
model = judge_kwargs.get('model', 'exact_matching')
|
||||
assert model in ['chatgpt-1106', 'exact_matching']
|
||||
judge_kwargs.update({
|
||||
"max_tokens": 128,
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"presence_penalty": 1,
|
||||
})
|
||||
|
||||
suffix = eval_file.split('.')[-1]
|
||||
score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.xlsx')
|
||||
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
||||
nproc = judge_kwargs.pop('nproc', 4)
|
||||
|
||||
if not osp.exists(score_file):
|
||||
data = load(eval_file)
|
||||
if model != 'exact_matching':
|
||||
model = build_judge(system_prompt=sys_prompt, **judge_kwargs)
|
||||
else:
|
||||
model = None
|
||||
|
||||
lt = len(data)
|
||||
lines = [data.iloc[i] for i in range(lt)]
|
||||
tups = [(model, line) for line in lines]
|
||||
indices = [line['index'] for line in lines]
|
||||
|
||||
ans = {}
|
||||
if osp.exists(tmp_file):
|
||||
ans = load(tmp_file)
|
||||
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
||||
indices = [i for i in indices if i not in ans]
|
||||
|
||||
if len(indices):
|
||||
_ = track_progress_rich(
|
||||
evaluate_tempcompass_captioning,
|
||||
tups,
|
||||
nproc=nproc,
|
||||
chunksize=nproc,
|
||||
keys=indices,
|
||||
save=tmp_file,
|
||||
)
|
||||
ans = load(tmp_file)
|
||||
for idx, item in data.iterrows():
|
||||
data.loc[idx, 'score'] = ans[idx]['rating']
|
||||
dump(data, score_file)
|
||||
|
||||
rating = get_dimension_rating(score_file)
|
||||
return rating
|
||||
|
||||
|
||||
class TempCompass_YorN(VideoBaseDataset):
|
||||
|
||||
MD5 = 'c72c046d7fa0e82c8cd7462f2e844ea8'
|
||||
TYPE = 'Video-Y/N'
|
||||
|
||||
def __init__(self, dataset='TempCompass_YorN', nframe=0, fps=-1):
|
||||
self.type_data_list = {
|
||||
'yes_no': ('yes_no.json', './videos', '.mp4'),
|
||||
}
|
||||
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
|
||||
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return ['TempCompass_YorN']
|
||||
|
||||
def prepare_dataset(self, dataset_name='TempCompass_YorN', repo_id='lmms-lab/TempCompass'):
|
||||
def check_integrity(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
|
||||
if not osp.exists(data_file):
|
||||
return False
|
||||
|
||||
if md5(data_file) != self.MD5:
|
||||
return False
|
||||
|
||||
data = load(data_file)
|
||||
for idx, item in data.iterrows():
|
||||
if not osp.exists(osp.join(pth, item['prefix'], item['video'] + item['suffix'])):
|
||||
return False
|
||||
return True
|
||||
|
||||
cache_path = get_cache_path(repo_id)
|
||||
if cache_path is not None and check_integrity(cache_path):
|
||||
dataset_path = cache_path
|
||||
else:
|
||||
def read_parquet(pth):
|
||||
import pandas as pd
|
||||
for task_name in self.type_data_list.keys():
|
||||
if not osp.exists(osp.join(pth, f'{task_name}.json')):
|
||||
data = pd.read_parquet(osp.join(pth, task_name, 'test-00000-of-00001.parquet'))
|
||||
data.to_json(osp.join(pth, f'{task_name}.json'), orient='records', lines=False)
|
||||
|
||||
def unzip_videos(pth):
|
||||
import zipfile
|
||||
if not osp.exists(osp.join(pth, 'videos')):
|
||||
zip_file = osp.join(pth, 'tempcompass_videos.zip')
|
||||
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
|
||||
zip_ref.extractall(pth)
|
||||
|
||||
def generate_tsv(pth):
|
||||
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
||||
if osp.exists(data_file) and md5(data_file) == self.MD5:
|
||||
return
|
||||
self.data_list = []
|
||||
for k, v in self.type_data_list.items():
|
||||
with open(osp.join(pth, v[0]), 'r') as f:
|
||||
json_data = json.load(f)
|
||||
for data in json_data:
|
||||
self.data_list.append({
|
||||
'task_type': k,
|
||||
'prefix': v[1],
|
||||
'suffix': v[2],
|
||||
'video': data['video_id'],
|
||||
'question': data['question'].split('\n')[0],
|
||||
'answer': data['answer'],
|
||||
'dim': data['dim']
|
||||
})
|
||||
|
||||
data_df = pd.DataFrame(self.data_list)
|
||||
data_df = data_df.assign(index=range(len(data_df)))
|
||||
data_df.to_csv(data_file, sep='\t', index=False)
|
||||
|
||||
if modelscope_flag_set():
|
||||
from modelscope import dataset_snapshot_download
|
||||
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
||||
else:
|
||||
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
|
||||
read_parquet(dataset_path)
|
||||
unzip_videos(dataset_path)
|
||||
generate_tsv(dataset_path)
|
||||
|
||||
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
||||
return dict(root=dataset_path, data_file=data_file)
|
||||
|
||||
def qa_template(self, data):
|
||||
question = data['question']
|
||||
answer = data['answer']
|
||||
return question, answer
|
||||
|
||||
def save_video_frames(self, line):
|
||||
vid_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
|
||||
vid = decord.VideoReader(vid_path)
|
||||
video_info = {
|
||||
'fps': vid.get_avg_fps(),
|
||||
'n_frames': len(vid),
|
||||
}
|
||||
if self.nframe > 0 and self.fps < 0:
|
||||
step_size = len(vid) / (self.nframe + 1)
|
||||
indices = [int(i * step_size) for i in range(1, self.nframe + 1)]
|
||||
frame_paths = self.frame_paths(line['video'])
|
||||
elif self.fps > 0:
|
||||
# not constrained by num_frames, get frames by fps
|
||||
total_duration = video_info['n_frames'] / video_info['fps']
|
||||
required_frames = int(total_duration * self.fps)
|
||||
step_size = video_info['fps'] / self.fps
|
||||
indices = [int(i * step_size) for i in range(required_frames)]
|
||||
frame_paths = self.frame_paths_fps(line['video'], len(indices))
|
||||
|
||||
flag = np.all([osp.exists(p) for p in frame_paths])
|
||||
|
||||
if not flag:
|
||||
images = [vid[i].asnumpy() for i in indices]
|
||||
images = [Image.fromarray(arr) for arr in images]
|
||||
for im, pth in zip(images, frame_paths):
|
||||
if not osp.exists(pth):
|
||||
im.save(pth)
|
||||
|
||||
return frame_paths
|
||||
|
||||
def save_video_into_images(self, line):
|
||||
frame_paths = self.save_video_frames(line)
|
||||
return frame_paths
|
||||
|
||||
def build_prompt(self, line, video_llm):
|
||||
if isinstance(line, int):
|
||||
assert line < len(self)
|
||||
line = self.data.iloc[line]
|
||||
|
||||
question, answer = self.qa_template(line)
|
||||
message = []
|
||||
message.append(dict(type='text', value=question))
|
||||
video_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
|
||||
if video_llm:
|
||||
message.append(dict(type='video', value=video_path))
|
||||
else:
|
||||
img_frame_paths = self.save_video_into_images(line)
|
||||
for im in img_frame_paths:
|
||||
message.append(dict(type='image', value=im))
|
||||
message.append(dict(type='text', value='\nPlease answer yes or no:'))
|
||||
return message
|
||||
|
||||
@classmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
model = judge_kwargs.get('model', 'exact_matching')
|
||||
assert model in ['chatgpt-1106', 'exact_matching']
|
||||
judge_kwargs.update({
|
||||
"max_tokens": 128,
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"presence_penalty": 1,
|
||||
})
|
||||
|
||||
suffix = eval_file.split('.')[-1]
|
||||
score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.xlsx')
|
||||
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
||||
nproc = judge_kwargs.pop('nproc', 4)
|
||||
|
||||
if not osp.exists(score_file):
|
||||
data = load(eval_file)
|
||||
if model != 'exact_matching':
|
||||
model = build_judge(system_prompt=sys_prompt, **judge_kwargs)
|
||||
else:
|
||||
model = None
|
||||
|
||||
lt = len(data)
|
||||
lines = [data.iloc[i] for i in range(lt)]
|
||||
tups = [(model, line) for line in lines]
|
||||
indices = [line['index'] for line in lines]
|
||||
|
||||
ans = {}
|
||||
if osp.exists(tmp_file):
|
||||
ans = load(tmp_file)
|
||||
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
||||
indices = [i for i in indices if i not in ans]
|
||||
|
||||
if len(indices):
|
||||
_ = track_progress_rich(
|
||||
evaluate_tempcompass_YorN,
|
||||
tups,
|
||||
nproc=nproc,
|
||||
chunksize=nproc,
|
||||
keys=indices,
|
||||
save=tmp_file,
|
||||
)
|
||||
ans = load(tmp_file)
|
||||
for idx, item in data.iterrows():
|
||||
data.loc[idx, 'score'] = ans[idx]['rating']
|
||||
dump(data, score_file)
|
||||
|
||||
rating = get_dimension_rating(score_file)
|
||||
return rating
|
||||
88
eval_mm/vlmevalkit/vlmeval/dataset/text_base.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from abc import abstractmethod
|
||||
from ..smp import *
|
||||
|
||||
|
||||
class TextBaseDataset:
|
||||
MODALITY = 'TEXT'
|
||||
DATASET_URL = {}
|
||||
DATASET_MD5 = {}
|
||||
|
||||
def __init__(self, dataset='MMBench', **kwargs):
|
||||
self.dataset_name = dataset
|
||||
|
||||
data = self.load_data(dataset)
|
||||
|
||||
data['index'] = [str(x) for x in data['index']]
|
||||
|
||||
if np.all([istype(x, int) for x in data['index']]):
|
||||
data['index'] = [int(x) for x in data['index']]
|
||||
|
||||
self.data = data
|
||||
self.post_build(dataset)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return dict(self.data.iloc[idx])
|
||||
|
||||
def prepare_tsv(self, url, file_md5=None):
|
||||
data_root = LMUDataRoot()
|
||||
os.makedirs(data_root, exist_ok=True)
|
||||
update_flag = False
|
||||
file_name = url.split('/')[-1]
|
||||
data_path = osp.join(data_root, file_name)
|
||||
if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
|
||||
pass
|
||||
else:
|
||||
warnings.warn('The dataset tsv is not downloaded')
|
||||
download_file(url, data_path)
|
||||
update_flag = True
|
||||
|
||||
if file_size(data_path, 'GB') > 1:
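# TSVs over 1 GB are swapped for a *_local.tsv copy built by LOCALIZE; the copy is rebuilt when it
# is missing, when FORCE_LOCAL is set, or right after a fresh download (update_flag)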
|
||||
local_path = data_path.replace('.tsv', '_local.tsv')
|
||||
if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None) or update_flag:
|
||||
from ..tools import LOCALIZE
|
||||
LOCALIZE(data_path, local_path)
|
||||
data_path = local_path
|
||||
return load(data_path)
|
||||
|
||||
def dump_image(self, line):
|
||||
return []
|
||||
|
||||
def display(self, line):
|
||||
if isinstance(line, int):
|
||||
line = self.data.iloc[line]
|
||||
assert isinstance(line, pd.Series) or isinstance(line, dict)
|
||||
mmqa_display(line)
|
||||
|
||||
# Return a list of dataset names that are supported by this class, can override
|
||||
@classmethod
|
||||
def supported_datasets(cls):
|
||||
return list(cls.DATASET_URL)
|
||||
|
||||
# Given the dataset name, return the dataset as a pandas dataframe, can override
|
||||
def load_data(self, dataset):
|
||||
url = self.DATASET_URL[dataset]
|
||||
file_md5 = self.DATASET_MD5[dataset]
|
||||
return self.prepare_tsv(url, file_md5)
|
||||
|
||||
# Post built hook, will be called after the dataset is built, can override
|
||||
def post_build(self, dataset):
|
||||
pass
|
||||
|
||||
# Given one data record, return the built prompt (a multi-modal message), can override
|
||||
def build_prompt(self, line):
|
||||
if isinstance(line, int):
|
||||
line = self.data.iloc[line]
|
||||
|
||||
question = line['question']
|
||||
|
||||
msgs = []
|
||||
msgs.append(dict(type='text', value=question))
|
||||
return msgs
|
||||
|
||||
# Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
|
||||
@abstractmethod
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
pass
|
||||
123
eval_mm/vlmevalkit/vlmeval/dataset/text_mcq.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from .text_base import TextBaseDataset
|
||||
from .utils import build_judge, DEBUG_MESSAGE
|
||||
from ..smp import *
|
||||
|
||||
|
||||
class TextMCQDataset(TextBaseDataset):
|
||||
TYPE = 'MCQ'
|
||||
|
||||
DATASET_URL = {}
|
||||
|
||||
DATASET_MD5 = {}
|
||||
|
||||
def build_prompt(self, line):
|
||||
|
||||
if isinstance(line, int):
|
||||
line = self.data.iloc[line]
|
||||
|
||||
question = line['question']
|
||||
options = {
|
||||
cand: line[cand]
|
||||
for cand in string.ascii_uppercase
|
||||
if cand in line and not pd.isna(line[cand])
|
||||
}
|
||||
options_prompt = 'Options:\n'
|
||||
for key, item in options.items():
|
||||
options_prompt += f'{key}. {item}\n'
|
||||
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
|
||||
prompt = ''
|
||||
if hint is not None:
|
||||
prompt += f'Hint: {hint}\n'
|
||||
prompt += f'Question: {question}\n'
|
||||
if len(options):
|
||||
prompt += options_prompt
|
||||
prompt += 'Please select the correct answer from the options above. \n'
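# The assembled prompt therefore looks like (hint line only when present):
#   Hint: <hint>
#   Question: <question>
#   Options:
#   A. <option A>
#   B. <option B>
#   ...
#   Please select the correct answer from the options above.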
|
||||
|
||||
msgs = []
|
||||
|
||||
msgs.append(dict(type='text', value=prompt))
|
||||
|
||||
return msgs
|
||||
|
||||
def evaluate(self, eval_file, **judge_kwargs):
|
||||
from .utils.multiple_choice import report_acc, report_acc_MMT, mcq_circular_eval, mcq_vanilla_eval
|
||||
# assert dataset is not None
|
||||
dataset_map = {
|
||||
'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_EN_V11': 'MMBench_V11',
|
||||
'MMBench_TEST_CN': 'MMBench_CN', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11'
|
||||
}
|
||||
dataset = self.dataset_name
|
||||
if dataset in dataset_map:
|
||||
dataset = dataset_map[dataset]
|
||||
nproc = judge_kwargs.pop('nproc', 4)
|
||||
|
||||
circular = False
|
||||
|
||||
suffix = eval_file.split('.')[-1]
|
||||
model = judge_kwargs.get('model', 'exact_matching')
|
||||
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
||||
name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
|
||||
name_str = name_str_map[model] if model in name_str_map else model
|
||||
|
||||
if model == 'exact_matching':
|
||||
model = None
|
||||
elif gpt_key_set():
|
||||
model = build_judge(**judge_kwargs)
|
||||
if not model.working():
|
||||
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
||||
warnings.warn(DEBUG_MESSAGE)
|
||||
model = None
|
||||
else:
|
||||
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
||||
model = None
|
||||
|
||||
result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
|
||||
|
||||
data = load(eval_file)
|
||||
data = data.sort_values(by='index')
|
||||
data['prediction'] = [str(x) for x in data['prediction']]
|
||||
# If not choice label, then use lower case
|
||||
for k in data.keys():
|
||||
data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
|
||||
|
||||
meta = self.data
|
||||
meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
|
||||
data_map = {x: y for x, y in zip(data['index'], data['question'])}
|
||||
for k in data_map:
|
||||
assert k in meta_q_map, (
|
||||
f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
|
||||
)
|
||||
|
||||
if circular:
|
||||
data = mcq_circular_eval(model, data, meta, nproc, result_file, self.dataset_name)
|
||||
else:
|
||||
data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
|
||||
|
||||
# load split
|
||||
dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
||||
data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
||||
|
||||
# May have different report acc functions for different datasets
|
||||
if 'MMT' in dataset:
|
||||
acc = report_acc_MMT(data)
|
||||
else:
|
||||
acc = report_acc(data)
|
||||
|
||||
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
||||
dump(acc, score_file)
|
||||
|
||||
return acc
|
||||
|
||||
|
||||
class CustomTextMCQDataset(TextMCQDataset):
|
||||
|
||||
def load_data(self, dataset):
|
||||
data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')
|
||||
|
||||
if file_size(data_path, 'GB') > 1:
|
||||
local_path = data_path.replace('.tsv', '_local.tsv')
|
||||
if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None):
|
||||
from ..tools import LOCALIZE
|
||||
LOCALIZE(data_path, local_path)
|
||||
data_path = local_path
|
||||
return load(data_path)
|
||||
9
eval_mm/vlmevalkit/vlmeval/dataset/utils/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .judge_util import build_judge, DEBUG_MESSAGE
|
||||
from .multiple_choice import extract_answer_from_item, prefetch_answer
|
||||
from .vqa_eval import levenshtein_distance
|
||||
|
||||
|
||||
__all__ = [
|
||||
'build_judge', 'extract_answer_from_item', 'prefetch_answer',
|
||||
'levenshtein_distance', 'DEBUG_MESSAGE',
|
||||
]
|
||||
@@ -0,0 +1,59 @@
|
||||
# CC-OCR: A Comprehensive and Challenging OCR Benchmark for Evaluating Large Multimodal Models in Literacy
|
||||
|
||||
## Introduction
|
||||
|
||||
Please refer to our [GitHub](https://github.com/AlibabaResearch/AdvancedLiterateMachinery/tree/main/Benchmarks/CC-OCR) for more information.
|
||||
|
||||
## Running Scripts
|
||||
|
||||
Once the environment is ready, execute the following script from the root directory of VLMEvalKit
|
||||
to perform inference and evaluation tasks in batch.
|
||||
|
||||
```shell
|
||||
MODEL_NAME="QwenVLMax"
|
||||
OUTPUT_DIR="/your/path/to/output_dir"
|
||||
|
||||
SUB_OUTPUT_DIR=${OUTPUT_DIR}/multi_scene_ocr
|
||||
python run.py --data CCOCR_MultiSceneOcr_Cord CCOCR_MultiSceneOcr_Funsd CCOCR_MultiSceneOcr_Iam CCOCR_MultiSceneOcr_ZhDoc CCOCR_MultiSceneOcr_ZhHandwriting CCOCR_MultiSceneOcr_Hieragent CCOCR_MultiSceneOcr_Ic15 CCOCR_MultiSceneOcr_Inversetext CCOCR_MultiSceneOcr_Totaltext CCOCR_MultiSceneOcr_ZhScene CCOCR_MultiSceneOcr_UgcLaion CCOCR_MultiSceneOcr_ZhDense CCOCR_MultiSceneOcr_ZhVertical --model ${MODEL_NAME} --work-dir ${SUB_OUTPUT_DIR} --verbose
|
||||
python vlmeval/dataset/utils/ccocr_evaluator/common.py ${SUB_OUTPUT_DIR}
|
||||
|
||||
SUB_OUTPUT_DIR=${OUTPUT_DIR}/multi_lan_ocr
|
||||
python run.py --data CCOCR_MultiLanOcr_Arabic CCOCR_MultiLanOcr_French CCOCR_MultiLanOcr_German CCOCR_MultiLanOcr_Italian CCOCR_MultiLanOcr_Japanese CCOCR_MultiLanOcr_Korean CCOCR_MultiLanOcr_Portuguese CCOCR_MultiLanOcr_Russian CCOCR_MultiLanOcr_Spanish CCOCR_MultiLanOcr_Vietnamese --model ${MODEL_NAME} --work-dir ${SUB_OUTPUT_DIR} --verbose
|
||||
python vlmeval/dataset/utils/ccocr_evaluator/common.py ${SUB_OUTPUT_DIR}
|
||||
|
||||
SUB_OUTPUT_DIR=${OUTPUT_DIR}/doc_parsing
|
||||
python run.py --data CCOCR_DocParsing_DocPhotoChn CCOCR_DocParsing_DocPhotoEng CCOCR_DocParsing_DocScanChn CCOCR_DocParsing_DocScanEng CCOCR_DocParsing_TablePhotoChn CCOCR_DocParsing_TablePhotoEng CCOCR_DocParsing_TableScanChn CCOCR_DocParsing_TableScanEng CCOCR_DocParsing_MolecularHandwriting CCOCR_DocParsing_FormulaHandwriting --model ${MODEL_NAME} --work-dir ${SUB_OUTPUT_DIR} --verbose
|
||||
python vlmeval/dataset/utils/ccocr_evaluator/common.py ${SUB_OUTPUT_DIR}
|
||||
|
||||
SUB_OUTPUT_DIR=${OUTPUT_DIR}/kie
|
||||
python run.py --data CCOCR_Kie_Sroie2019Word CCOCR_Kie_Cord CCOCR_Kie_EphoieScut CCOCR_Kie_Poie CCOCR_Kie_ColdSibr CCOCR_Kie_ColdCell --model ${MODEL_NAME} --work-dir ${SUB_OUTPUT_DIR} --verbose
|
||||
python vlmeval/dataset/utils/ccocr_evaluator/common.py ${SUB_OUTPUT_DIR}
|
||||
```
|
||||
|
||||
## Example Output
|
||||
The evaluation results will be saved in `${SUB_OUTPUT_DIR}/summary.md`. For example, for the KIE subset,
|
||||
the output is as follows:
|
||||
|
||||
| exp_name(f1_score) | COLD_CELL | COLD_SIBR | CORD | EPHOIE_SCUT | POIE | sroie2019_word | summary |
|
||||
|:-------------------|------------:|------------:|-------:|--------------:|-------:|-----------------:|----------:|
|
||||
| QwenVLMax | 81.01 | 72.46 | 69.33 | 71.2 | 60.85 | 76.37 | 71.87 |
|
||||
|
||||
|
||||
## Citation
|
||||
If you find our work helpful, please consider citing us.
|
||||
|
||||
```
|
||||
@misc{yang2024ccocr,
|
||||
title={CC-OCR: A Comprehensive and Challenging OCR Benchmark for Evaluating Large Multimodal Models in Literacy},
|
||||
author={Zhibo Yang and Jun Tang and Zhaohai Li and Pengfei Wang and Jianqiang Wan and Humen Zhong and Xuejing Liu and Mingkun Yang and Peng Wang and Shuai Bai and LianWen Jin and Junyang Lin},
|
||||
year={2024},
|
||||
eprint={2412.02210},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV},
|
||||
url={https://arxiv.org/abs/2412.02210},
|
||||
}
|
||||
```
|
||||
|
||||
## Contact Us
|
||||
|
||||
If you have any questions, feel free to send an email to: wpf272043@alibaba-inc.com or xixing.tj@alibaba-inc.com
|
||||
@@ -0,0 +1,12 @@
|
||||
from .kie_evaluator import KieEvaluator
|
||||
from .doc_parsing_evaluator import ParsingEvaluator
|
||||
from .ocr_evaluator import OcrEvaluator
|
||||
from .common import summary
|
||||
|
||||
|
||||
evaluator_map_info = {
|
||||
"kie": KieEvaluator("kie"),
|
||||
"doc_parsing": ParsingEvaluator("doc_parsing"),
|
||||
"multi_lan_ocr": OcrEvaluator("multi_lan_ocr"),
|
||||
"multi_scene_ocr": OcrEvaluator("multi_scene_ocr")
|
||||
}
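# Sketch of intended usage (assuming each evaluator follows BaseMetric.__call__ from common.py,
# as ParsingEvaluator does):
#   meta_info, eval_info = evaluator_map_info["kie"](pdt_res_dir_or_dict, gt_info)
#   print(eval_info["summary"])  # metric values that feed summary.md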
|
||||
@@ -0,0 +1,222 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from tabulate import tabulate
|
||||
|
||||
|
||||
def pick_response_text(json_path):
|
||||
"""
|
||||
"""
|
||||
try:
|
||||
with open(json_path, "r") as f:
|
||||
json_data = json.load(f)
|
||||
except Exception as e:
|
||||
print("--> file error: msg: {}, path: {}".format(e, json_path))
|
||||
return None
|
||||
|
||||
for required_key in ["model_name", "response"]:
|
||||
if required_key not in json_data:
|
||||
print("--> required key not exists, name: {}, path: {}".format(required_key, json_path))
|
||||
return None
|
||||
|
||||
model_name = json_data["model_name"]
|
||||
model_response = json_data["response"]
|
||||
|
||||
response_text = None
|
||||
if model_name.startswith("gpt") or model_name.startswith("o1"):
|
||||
response_text = model_response.get("data", {}).get("response", {}).get("choices", [{}])[0].get("message", {}).get("content", None) # noqa: E501
|
||||
elif model_name.startswith("local_"):
|
||||
response_text = model_response
|
||||
else:
|
||||
if model_name.startswith("claude"):
|
||||
content_list = model_response.get("content", None)
|
||||
elif model_name.startswith("gemini"):
|
||||
content_list = model_response.get("candidates", [{}])[0].get("content", {}).get("parts", None)
|
||||
elif model_name.startswith("qwen"):
|
||||
content_list = model_response.get("output", {}).get("choices", [{}])[0].get("message", {}).get("content", None) # noqa: E501
|
||||
else:
|
||||
raise NotImplementedError("The pick_response_text NOT implemented for model: {}".format(model_name))
|
||||
|
||||
if isinstance(content_list, list) and len(content_list) > 0:
|
||||
response_text = content_list[0].get("text", None)
|
||||
|
||||
if response_text is None:
|
||||
print("--> [error][{}] text pick error, path: {}".format(model_name, json_path))
|
||||
return response_text
|
||||
|
||||
|
||||
def load_response_from_dir(res_dir):
|
||||
"""
|
||||
"""
|
||||
response_info = {}
|
||||
for file_name in os.listdir(res_dir):
|
||||
file_path = os.path.abspath(os.path.join(res_dir, file_name))
|
||||
if not file_name.endswith(".json"):
|
||||
print("--> skip: result file should be a json: but got: {}".format(file_path))
|
||||
continue
|
||||
|
||||
response_text = pick_response_text(file_path)
|
||||
if response_text is None:
|
||||
continue
|
||||
|
||||
file_name_wo_ext, ext = os.path.splitext(file_name)
|
||||
response_info[file_name_wo_ext] = response_text
|
||||
return response_info
|
||||
|
||||
|
||||
class BaseMetric(object):
|
||||
""" BaseMetric """
|
||||
""" OCRMetric """
|
||||
def __init__(self, group_name, **kwargs):
|
||||
self.group_name = group_name
|
||||
self.kwargs = kwargs
|
||||
|
||||
def response_post_func(self, response_text, **kwargs):
|
||||
return response_text
|
||||
|
||||
@abstractmethod
|
||||
# Given the prediction and gt, return the evaluation results in the format of a dictionary
|
||||
# results should contain a 'summary' key, for example:
|
||||
# {
|
||||
# "summary": {
|
||||
# "f1-score": 99.99,
|
||||
# "metric_name": "metric_value" # used for summary,only metric info could be placed in this dict.
|
||||
# },
|
||||
# "your other info": "xxx"
|
||||
# }
|
||||
def evaluate(self, response_info, gt_info, normalize_func=None, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, pdt_res_dir, gt_info, with_response_ratio=True, **kwargs):
|
||||
if isinstance(pdt_res_dir, dict):
|
||||
raw_response_info = pdt_res_dir
|
||||
elif os.path.exists(pdt_res_dir) and os.path.isdir(pdt_res_dir):
|
||||
raw_response_info = load_response_from_dir(pdt_res_dir)
|
||||
else:
|
||||
return ValueError("invalid input: response dict or folder are required, but got {}".format(pdt_res_dir))
|
||||
|
||||
post_error_list, response_info = [], {}
|
||||
response_error_list = list(gt_info.keys() - raw_response_info.keys())
|
||||
for file_name, single_pdt_str in raw_response_info.items():
|
||||
single_pdt_str = self.response_post_func(single_pdt_str, **kwargs)
|
||||
if single_pdt_str is None:
|
||||
post_error_list.append(file_name)
|
||||
continue
|
||||
response_info[file_name] = single_pdt_str
|
||||
|
||||
meta_info = {
|
||||
"gt_total_num": len(gt_info), "pdt_total_num": len(response_info),
|
||||
"post_error_list": post_error_list, "response_error_list": response_error_list,
|
||||
}
|
||||
eval_info = self.evaluate(response_info, gt_info, **kwargs)
|
||||
|
||||
# add response_success_ratio
|
||||
if "summary" in eval_info and with_response_ratio:
|
||||
success_ratio = (len(response_info) + len(post_error_list)) / (len(gt_info) + 1e-9)
|
||||
eval_info["summary"].update({"response_success_ratio": success_ratio})
|
||||
return meta_info, eval_info
|
||||
|
||||
|
||||
def summary(index_path, exp_dir_base, is_weighted_sum=False):
|
||||
"""
|
||||
"""
|
||||
with open(index_path, "r") as f:
|
||||
data_list = json.load(f)
|
||||
|
||||
all_data_info = {}
|
||||
for data_info_item in data_list:
|
||||
data_name = data_info_item["dataset"]
|
||||
if not data_info_item.get("release", True):
|
||||
continue
|
||||
all_data_info[data_name] = data_info_item
|
||||
dataset_list = list(all_data_info.keys())
|
||||
summary_path = summary_multi_exp(exp_dir_base, dataset_list, is_weighted_sum=is_weighted_sum)
|
||||
return summary_path
|
||||
|
||||
|
||||
def summary_multi_exp(exp_dir_base, dataset_list=None, is_weighted_sum=False):
|
||||
"""
|
||||
"""
|
||||
if dataset_list is None:
|
||||
all_dataset_name = []
|
||||
for exp_name in os.listdir(exp_dir_base):
|
||||
dir_status_path = os.path.join(exp_dir_base, exp_name, "status.json")
|
||||
if not os.path.exists(dir_status_path):
|
||||
continue
|
||||
with open(dir_status_path, "r") as f:
|
||||
data_status_info = json.load(f)
|
||||
all_dataset_name.extend(data_status_info.keys())
|
||||
dataset_list = sorted(set(all_dataset_name))
|
||||
|
||||
# summary main code
|
||||
all_evaluate_info = {}
|
||||
for exp_name in os.listdir(exp_dir_base):
|
||||
dir_status_path = os.path.join(exp_dir_base, exp_name, "status.json")
|
||||
if not os.path.exists(dir_status_path):
|
||||
print("--> skip: status.json not exist: {}".format(dir_status_path))
|
||||
continue
|
||||
|
||||
with open(dir_status_path, "r") as f:
|
||||
all_status_info = json.load(f)
|
||||
|
||||
for data_name in dataset_list:
|
||||
total_num = all_status_info.get(data_name, {}).get("config", {}).get("num", "-1")
|
||||
summary_info = all_status_info.get(data_name, {}).get("evaluation", {}).get("summary", {})
|
||||
for metric_name, metric_value in summary_info.items():
|
||||
if metric_name not in all_evaluate_info:
|
||||
all_evaluate_info[metric_name] = {}
|
||||
if exp_name not in all_evaluate_info[metric_name]:
|
||||
all_evaluate_info[metric_name][exp_name] = {}
|
||||
all_evaluate_info[metric_name][exp_name][data_name] = (metric_value, total_num)
|
||||
|
||||
all_table_md = []
|
||||
for metric_name, metric_info in all_evaluate_info.items():
|
||||
formatted_time = time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time()))
|
||||
summary_line_list = []
|
||||
summary_key_name = "summary(weighted)" if is_weighted_sum else "summary"
|
||||
summary_head = [f"exp_name({metric_name}_{formatted_time})"] + dataset_list + [summary_key_name]
|
||||
for exp_name, data_eval_info in metric_info.items():
|
||||
summary_line = [exp_name, ]
|
||||
|
||||
all_metric_value = 0
|
||||
is_summary_valid, all_total_num, all_weighted_metric = True, 0, 0
|
||||
for data_name in dataset_list:
|
||||
metric_value, total_num = data_eval_info.get(data_name, ("-1", "-1"))
|
||||
summary_line.append("{:.2f}".format(float(metric_value) * 100))
|
||||
if str(metric_value) == "-1" or str(metric_value) == "-1":
|
||||
is_summary_valid = False
|
||||
continue
|
||||
|
||||
all_total_num += float(total_num)
|
||||
all_weighted_metric += float(total_num) * float(metric_value)
|
||||
all_metric_value += float(metric_value)
|
||||
|
||||
summary_value_valid = ((all_weighted_metric / (all_total_num + 1e-9)) * 100) if is_weighted_sum \
|
||||
else (all_metric_value / (len(dataset_list) + 1e-9) * 100)
|
||||
summary_value = "-" if not is_summary_valid else "{:.2f}".format(summary_value_valid)
|
||||
summary_line.append(summary_value)
|
||||
summary_line_list.append(summary_line)
|
||||
|
||||
md_table_info = tabulate(summary_line_list, headers=summary_head, tablefmt='pipe')
|
||||
all_table_md.append(md_table_info)
|
||||
|
||||
print("\n\n".join(all_table_md))
|
||||
summary_path = os.path.abspath(os.path.join(exp_dir_base, "summary.md"))
|
||||
with open(summary_path, "w") as f:
|
||||
f.write("\n\n".join(all_table_md))
|
||||
return summary_path
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python {} exp_base_dir".format(__file__))
|
||||
exit(-1)
|
||||
else:
|
||||
print('--> info: {}'.format(sys.argv))
|
||||
exp_base_dir = sys.argv[1]
|
||||
|
||||
summary_path = summary_multi_exp(exp_base_dir, dataset_list=None, is_weighted_sum=False)
|
||||
print("--> info: summary saved at : {}".format(summary_path))
|
||||
print("happy coding.")
|
||||
@@ -0,0 +1,256 @@
|
||||
import nltk
|
||||
import re
|
||||
from tqdm import tqdm
|
||||
from collections import deque
|
||||
from apted.helpers import Tree
|
||||
from apted import APTED, Config
|
||||
|
||||
# local import
|
||||
from .common import BaseMetric
|
||||
|
||||
|
||||
# LaTeX boilerplate commands to strip from predictions before scoring
|
||||
patterns = [
|
||||
r'\\documentclass\{.*?\}',
|
||||
r'\\usepackage\[.*?\]\{.*?\}',
|
||||
r'\\usepackage\{.*?\}',
|
||||
r'\\geometry\{.*?\}',
|
||||
r'\\begin\{document\}',
|
||||
r'\\end\{document\}',
|
||||
r'\\noindent'
|
||||
]
|
||||
|
||||
|
||||
class TableTree(Tree):
|
||||
"""
|
||||
# Copyright 2020 IBM
|
||||
# Author: peter.zhong@au1.ibm.com
|
||||
# License: Apache 2.0 License.
|
||||
"""
|
||||
def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
|
||||
self.tag = tag
|
||||
self.colspan = colspan
|
||||
self.rowspan = rowspan
|
||||
self.content = content
|
||||
self.children = list(children)
|
||||
|
||||
def bracket(self):
|
||||
"""Show tree using brackets notation"""
|
||||
if self.tag == "td":
|
||||
result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
|
||||
self.tag,
|
||||
self.colspan,
|
||||
self.rowspan,
|
||||
self.content,
|
||||
)
|
||||
else:
|
||||
result = '"tag": %s' % self.tag
|
||||
for child in self.children:
|
||||
result += child.bracket()
|
||||
return "{{{}}}".format(result)
|
||||
|
||||
|
||||
class CustomConfig(Config):
|
||||
"""
|
||||
# Copyright 2020 IBM
|
||||
# Author: peter.zhong@au1.ibm.com
|
||||
# License: Apache 2.0 License.
|
||||
"""
|
||||
def rename(self, node1, node2):
|
||||
"""Compares attributes of trees"""
|
||||
# print(node1.tag)
|
||||
if (
|
||||
(node1.tag != node2.tag)
|
||||
or (node1.colspan != node2.colspan)
|
||||
or (node1.rowspan != node2.rowspan)
|
||||
):
|
||||
return 1.0
|
||||
if node1.tag == "td":
|
||||
if node1.content or node2.content:
|
||||
return nltk.edit_distance(node1.content, node2.content) / max(len(node1.content), len(node2.content))
|
||||
return 0.0
|
||||
|
||||
|
||||
class TEDS(object):
|
||||
"""Tree Edit Distance basead Similarity
|
||||
# Copyright 2020 IBM
|
||||
# Author: peter.zhong@au1.ibm.com
|
||||
# License: Apache 2.0 License.
|
||||
"""
|
||||
def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
|
||||
assert isinstance(n_jobs, int) and (
|
||||
n_jobs >= 1
|
||||
), "n_jobs must be an integer greather than 1"
|
||||
self.structure_only = structure_only
|
||||
self.n_jobs = n_jobs
|
||||
self.ignore_nodes = ignore_nodes
|
||||
self.__tokens__ = []
|
||||
|
||||
def tokenize(self, node):
|
||||
"""Tokenizes table cells"""
|
||||
self.__tokens__.append("<%s>" % node.tag)
|
||||
if node.text is not None:
|
||||
self.__tokens__ += list(node.text)
|
||||
for n in node.getchildren():
|
||||
self.tokenize(n)
|
||||
if node.tag != "unk":
|
||||
self.__tokens__.append("</%s>" % node.tag)
|
||||
if node.tag != "td" and node.tail is not None:
|
||||
self.__tokens__ += list(node.tail)
|
||||
|
||||
def load_html_tree(self, node, parent=None):
|
||||
"""Converts HTML tree to the format required by apted"""
|
||||
global __tokens__
|
||||
if node.tag == "td":
|
||||
if self.structure_only:
|
||||
cell = []
|
||||
else:
|
||||
self.__tokens__ = []
|
||||
self.tokenize(node)
|
||||
cell = self.__tokens__[1:-1].copy()
|
||||
new_node = TableTree(
|
||||
node.tag,
|
||||
int(node.attrib.get("colspan", "1")),
|
||||
int(node.attrib.get("rowspan", "1")),
|
||||
cell,
|
||||
*deque(),
|
||||
)
|
||||
else:
|
||||
new_node = TableTree(node.tag, None, None, None, *deque())
|
||||
if parent is not None:
|
||||
parent.children.append(new_node)
|
||||
if node.tag != "td":
|
||||
for n in node.getchildren():
|
||||
self.load_html_tree(n, new_node)
|
||||
if parent is None:
|
||||
return new_node
|
||||
|
||||
def evaluate(self, pred, true):
|
||||
"""Computes TEDS score between the prediction and the ground truth of a
|
||||
given sample
|
||||
"""
|
||||
# try_import("lxml")
|
||||
from lxml import etree, html
|
||||
if (not pred) or (not true):
|
||||
return 0.0
|
||||
|
||||
parser = html.HTMLParser(remove_comments=True, encoding="utf-8")
|
||||
pred = html.fromstring(pred, parser=parser)
|
||||
true = html.fromstring(true, parser=parser)
|
||||
if pred.xpath("body/table") and true.xpath("body/table"):
|
||||
pred = pred.xpath("body/table")[0]
|
||||
true = true.xpath("body/table")[0]
|
||||
if self.ignore_nodes:
|
||||
etree.strip_tags(pred, *self.ignore_nodes)
|
||||
etree.strip_tags(true, *self.ignore_nodes)
|
||||
n_nodes_pred = len(pred.xpath(".//*"))
|
||||
n_nodes_true = len(true.xpath(".//*"))
|
||||
n_nodes = max(n_nodes_pred, n_nodes_true)
|
||||
tree_pred = self.load_html_tree(pred)
|
||||
tree_true = self.load_html_tree(true)
|
||||
distance = APTED(
|
||||
tree_pred, tree_true, CustomConfig()
|
||||
).compute_edit_distance()
|
||||
return 1.0 - (float(distance) / n_nodes)
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
|
||||
class ParsingEvaluator(BaseMetric):
|
||||
def response_post_func(self, response_text, **kwargs):
|
||||
return response_text
|
||||
|
||||
def evaluate(self, response_info, gt_info, **kwargs):
|
||||
op = kwargs['op']
|
||||
if op == 'doc':
|
||||
score = self.eval_doc(response_info, gt_info)
|
||||
elif op == 'table':
|
||||
score = self.eval_table(response_info, gt_info)
|
||||
elif op in ['molecular', "formula"]:
|
||||
score = self.eval_formula(response_info, gt_info, op_name=op)
|
||||
else:
|
||||
raise ValueError(f'doc parsing unsupported op: {op}')
|
||||
|
||||
# summary info
|
||||
eval_info = {"summary": {"score": score}}
|
||||
return eval_info
|
||||
|
||||
def eval_doc(self, response_info, gt_info):
|
||||
results = []
|
||||
for img_name, gt in tqdm(gt_info.items()):
|
||||
if img_name not in response_info:
|
||||
results.append(0)
|
||||
continue
|
||||
|
||||
pred = response_info[img_name]
|
||||
for pattern in patterns:
|
||||
pred = re.sub(pattern, '', pred)
|
||||
|
||||
try:
|
||||
pred = pred.split('```')[1]
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
pred = pred.replace('```latex', '')
|
||||
pred = pred.replace('```', '')
|
||||
|
||||
pred = pred.replace(' ', '').replace('\n', '')
|
||||
gt = gt.replace(' ', '').replace('\n', '')
|
||||
|
||||
edit_dist = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
|
||||
results.append(1 - edit_dist)
|
||||
|
||||
score = sum(results) / len(results)
|
||||
return score
|
||||
|
||||
def eval_table(self, response_info, gt_info):
|
||||
teds = TEDS(structure_only=False, n_jobs=1)
|
||||
results = []
|
||||
for img_name, gt in tqdm(gt_info.items()):
|
||||
if img_name not in response_info:
|
||||
results.append(0)
|
||||
continue
|
||||
|
||||
pred = response_info[img_name]
|
||||
for pattern in patterns:
|
||||
pred = re.sub(pattern, '', pred)
|
||||
|
||||
try:
|
||||
pred = pred.split('```html')[1]
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
pred = pred.replace('```', '')
|
||||
pred = pred.replace(' ', '').replace('\n', '').replace(',', ',')
|
||||
gt = gt.replace(' ', '').replace('\n', '')
|
||||
|
||||
pred_html = '<html><body>{}</body></html>'.format(pred)
|
||||
gt_html = '<html><body>{}</body></html>'.format(gt)
|
||||
results.append(teds.evaluate(pred_html, gt_html))
|
||||
|
||||
score = sum(results) / len(results)
|
||||
return score
|
||||
|
||||
def eval_formula(self, response_info, gt_info, op_name='formula'):
|
||||
results = []
|
||||
for img_name, gt in tqdm(gt_info.items()):
|
||||
if img_name not in response_info:
|
||||
results.append(0)
|
||||
continue
|
||||
|
||||
pred = response_info[img_name]
|
||||
|
||||
if op_name == 'formula':
|
||||
pred = pred.replace("\n", " ").replace("```latex", "").replace("```", "").replace("\t", " ").replace(" ", "") # noqa: E501
|
||||
gt = gt.replace(" ", "")
|
||||
elif op_name == 'molecular':
|
||||
pred = pred.replace("\n", "").replace(" ", "").replace("<smiles>", "").replace("</smiles>", "")
|
||||
gt = gt.replace(" ", "")
|
||||
edit_dist = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
|
||||
results.append(1 - edit_dist)
|
||||
score = sum(results) / len(results)
|
||||
return score
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
@@ -0,0 +1,385 @@
|
||||
|
||||
"""
|
||||
Donut
|
||||
Copyright (c) 2022-present NAVER Corp.
|
||||
MIT License
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import zss
|
||||
from zss import Node
|
||||
from collections import Counter
|
||||
from nltk import edit_distance
|
||||
|
||||
# local import
|
||||
from .common import BaseMetric
|
||||
|
||||
|
||||
def flatten(data: dict):
|
||||
"""
|
||||
Convert Dictionary into Non-nested Dictionary
|
||||
Example:
|
||||
input(dict)
|
||||
{
|
||||
"menu": [
|
||||
{"name" : ["cake"], "count" : ["2"]},
|
||||
{"name" : ["juice"], "count" : ["1"]},
|
||||
]
|
||||
}
|
||||
output(list)
|
||||
[
|
||||
("menu.name", "cake"),
|
||||
("menu.count", "2"),
|
||||
("menu.name", "juice"),
|
||||
("menu.count", "1"),
|
||||
]
|
||||
"""
|
||||
flatten_data = list()
|
||||
|
||||
def _flatten(value, key=""):
|
||||
if type(value) is dict:
|
||||
for child_key, child_value in value.items():
|
||||
_flatten(child_value, f"{key}.{child_key}" if key else child_key)
|
||||
elif type(value) is list:
|
||||
for value_item in value:
|
||||
_flatten(value_item, key)
|
||||
else:
|
||||
flatten_data.append((key, value))
|
||||
|
||||
_flatten(data)
|
||||
return flatten_data
|
||||
|
||||
|
||||
def update_cost(node1: Node, node2: Node):
|
||||
"""
|
||||
Update cost for tree edit distance.
|
||||
If both are leaf nodes, calculate the string edit distance between the two labels (the special token '<leaf>' is ignored).
|
||||
If only one of them is a leaf node, the cost is the length of the string in the leaf node + 1.
|
||||
If neither is a leaf node, the cost is 0 if label1 equals label2, otherwise 1.
|
||||
"""
|
||||
label1 = node1.label
|
||||
label2 = node2.label
|
||||
label1_leaf = "<leaf>" in label1
|
||||
label2_leaf = "<leaf>" in label2
|
||||
if label1_leaf and label2_leaf:
|
||||
return edit_distance(label1.replace("<leaf>", ""), label2.replace("<leaf>", ""))
|
||||
elif not label1_leaf and label2_leaf:
|
||||
return 1 + len(label2.replace("<leaf>", ""))
|
||||
elif label1_leaf and not label2_leaf:
|
||||
return 1 + len(label1.replace("<leaf>", ""))
|
||||
else:
|
||||
return int(label1 != label2)
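# Worked example (hypothetical labels, not part of the original file) of the three cost cases above:
#   update_cost(Node('<leaf>cake'), Node('<leaf>cafe'))  -> edit_distance('cake', 'cafe') = 1
#   update_cost(Node('menu'), Node('<leaf>cake'))         -> 1 + len('cake') = 5
#   update_cost(Node('menu'), Node('name'))               -> 1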
|
||||
|
||||
|
||||
def insert_and_remove_cost(node: Node):
|
||||
"""
|
||||
Insert and remove cost for tree edit distance.
|
||||
If the node is a leaf, the cost is the length of its label (without '<leaf>').
|
||||
Otherwise, the cost is 1.
|
||||
"""
|
||||
label = node.label
|
||||
if "<leaf>" in label:
|
||||
return len(label.replace("<leaf>", ""))
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def normalize_dict(data: Union[Dict, List, Any]):
|
||||
"""
|
||||
Sort dictionary keys (by length, then alphabetically) and recurse into values; iterate over elements if data is a list.
|
||||
"""
|
||||
# if not data:
|
||||
# return {}
|
||||
|
||||
if isinstance(data, dict):
|
||||
new_data = dict()
|
||||
for key in sorted(data.keys(), key=lambda k: (len(k), k)):
|
||||
value = normalize_dict(data[key])
|
||||
if value:
|
||||
if not isinstance(value, list):
|
||||
value = [value]
|
||||
new_data[key] = value
|
||||
|
||||
elif isinstance(data, list):
|
||||
if all(isinstance(item, dict) for item in data):
|
||||
new_data = []
|
||||
for item in data:
|
||||
item = normalize_dict(item)
|
||||
if item:
|
||||
new_data.append(item)
|
||||
else:
|
||||
new_data = [str(item).strip() for item in data if type(item) in {str, int, float} and str(item).strip()]
|
||||
else:
|
||||
new_data = [str(data).strip()]
|
||||
return new_data
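# Usage sketch (hypothetical input) showing the key ordering and value wrapping described above:
#   >>> normalize_dict({"b": "1", "a": ["x", ""]})
#   {'a': ['x'], 'b': ['1']}   # keys sorted by (length, name); empty strings dropped; scalars wrapped in lists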
|
||||
|
||||
|
||||
def cal_f1_all(preds, answers):
|
||||
"""
|
||||
Calculate global F1 accuracy score (field-level, micro-averaged) by counting all true positives,
|
||||
false negatives and false positives
|
||||
"""
|
||||
metric_info, error_info = {}, {}
|
||||
total_tp, total_fn_or_fp = 0, 0
|
||||
for file_name, answer in answers.items():
|
||||
sample_error_info = {"fp": [], "fn": [], "tp": []}
|
||||
pred = preds.get(file_name, {})
|
||||
pred, answer = flatten(normalize_dict(pred)), flatten(normalize_dict(answer))
|
||||
for field in pred:
|
||||
field_name = field[0]
|
||||
if field_name not in metric_info:
|
||||
metric_info[field_name] = {"total_tp": 0, "total_fn_or_fp": 0}
|
||||
if field in answer:
|
||||
total_tp += 1
|
||||
metric_info[field_name]["total_tp"] += 1
|
||||
sample_error_info["tp"].append(field)
|
||||
answer.remove(field)
|
||||
else:
|
||||
total_fn_or_fp += 1
|
||||
metric_info[field_name]["total_fn_or_fp"] += 1
|
||||
sample_error_info["fp"].append(field)
|
||||
|
||||
total_fn_or_fp += len(answer)
|
||||
for field in answer:
|
||||
field_name = field[0]
|
||||
if field_name not in metric_info:
|
||||
metric_info[field_name] = {"total_tp": 0, "total_fn_or_fp": 0}
|
||||
metric_info[field_name]["total_fn_or_fp"] += 1
|
||||
sample_error_info["fn"].append(field)
|
||||
|
||||
sample_error_num = sum([len(v) for k, v in sample_error_info.items() if k != "tp"])
|
||||
if sample_error_num > 0:
|
||||
sample_error_info["error_num"] = sample_error_num
|
||||
error_class_list = ["counter_" + x[0] for x in (sample_error_info["fn"] + sample_error_info["fp"])]
|
||||
counter = Counter(error_class_list)
|
||||
sample_error_info["error_info"] = dict(counter)
|
||||
error_info[file_name] = sample_error_info
|
||||
|
||||
# summary
|
||||
for field_name, field_info in metric_info.items():
|
||||
field_tp, field_fn_or_fp = field_info["total_tp"], field_info["total_fn_or_fp"]
|
||||
metric_info[field_name]["acc"] = field_tp / (field_tp + field_fn_or_fp / 2 + 1e-6)
|
||||
|
||||
print("donut_evaluator: total_tp: {}, total_fn_or_fp: {}, ptd_num: {}, gt_num: {}".format(total_tp, total_fn_or_fp,
|
||||
len(preds), len(answers)))
|
||||
error_info = {k: v for k, v in
|
||||
sorted(error_info.items(), key=lambda item: item[1].get("error_num", 0), reverse=True)}
|
||||
metric_info = {k: v for k, v in
|
||||
sorted(metric_info.items(), key=lambda item: item[1].get("total_fn_or_fp", 0), reverse=True)}
|
||||
return total_tp / (total_tp + total_fn_or_fp / 2 + 1e-6), metric_info, error_info
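# Worked example (hypothetical counts) of the micro-F1 formula returned above:
# with 3 matched fields (TP), 1 missing field (FN) and 1 spurious field (FP),
# F1 = TP / (TP + (FN + FP) / 2) = 3 / (3 + 1) = 0.75,
# which equals the usual 2 * precision * recall / (precision + recall) with precision = recall = 3/4.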
|
||||
|
||||
|
||||
def construct_tree_from_dict(data: Union[Dict, List], node_name: str = None):
|
||||
"""
|
||||
Convert Dictionary into Tree
|
||||
|
||||
Example:
|
||||
input(dict)
|
||||
|
||||
{
|
||||
"menu": [
|
||||
{"name" : ["cake"], "count" : ["2"]},
|
||||
{"name" : ["juice"], "count" : ["1"]},
|
||||
]
|
||||
}
|
||||
|
||||
output(tree)
|
||||
<root>
|
||||
|
|
||||
menu
|
||||
/ \
|
||||
<subtree> <subtree>
|
||||
/ | | \
|
||||
name count name count
|
||||
/ | | \
|
||||
<leaf>cake <leaf>2 <leaf>juice <leaf>1
|
||||
"""
|
||||
if node_name is None:
|
||||
node_name = "<root>"
|
||||
|
||||
node = Node(node_name)
|
||||
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
kid_node = construct_tree_from_dict(value, key)
|
||||
node.addkid(kid_node)
|
||||
elif isinstance(data, list):
|
||||
if all(isinstance(item, dict) for item in data):
|
||||
for item in data:
|
||||
kid_node = construct_tree_from_dict(
|
||||
item,
|
||||
"<subtree>",
|
||||
)
|
||||
node.addkid(kid_node)
|
||||
else:
|
||||
for item in data:
|
||||
node.addkid(Node(f"<leaf>{item}"))
|
||||
else:
|
||||
raise Exception(data, node_name)
|
||||
return node
|
||||
|
||||
|
||||
def cal_acc(pred: dict, answer: dict):
|
||||
"""
|
||||
Calculate normalized tree edit distance (nTED) based accuracy.
|
||||
1) Construct tree from dict,
|
||||
2) Get tree distance with insert/remove/update cost,
|
||||
3) Divide distance with GT tree size (i.e., nTED),
|
||||
4) Calculate nTED-based accuracy (= max(1 - nTED, 0)).
|
||||
"""
|
||||
pred = construct_tree_from_dict(normalize_dict(pred))
|
||||
answer = construct_tree_from_dict(normalize_dict(answer))
|
||||
val1 = zss.distance(
|
||||
pred,
|
||||
answer,
|
||||
get_children=zss.Node.get_children,
|
||||
insert_cost=insert_and_remove_cost,
|
||||
remove_cost=insert_and_remove_cost,
|
||||
update_cost=update_cost,
|
||||
return_operations=False,
|
||||
)
|
||||
val2 = zss.distance(
|
||||
construct_tree_from_dict(normalize_dict({})),
|
||||
answer,
|
||||
get_children=zss.Node.get_children,
|
||||
insert_cost=insert_and_remove_cost,
|
||||
remove_cost=insert_and_remove_cost,
|
||||
update_cost=update_cost,
|
||||
return_operations=False,
|
||||
)
|
||||
return max(0, 1 - val1 / val2)
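# Sanity-check sketch (hypothetical inputs): identical trees give nTED accuracy 1.0 and an empty
# prediction gives 0.0, since the score is max(0, 1 - TED(pred, gt) / TED(empty, gt)).
#   >>> cal_acc({"menu": [{"name": ["cake"]}]}, {"menu": [{"name": ["cake"]}]})
#   1.0
#   >>> cal_acc({}, {"menu": [{"name": ["cake"]}]})
#   0.0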
|
||||
|
||||
|
||||
def cal_acc_all(pred_info, answer_info):
|
||||
acc_info, error_info = {}, {}
|
||||
for file_name, answer in answer_info.items():
|
||||
# if file_name not in pred_info:
|
||||
# print("---> error: pdt not found: {}".format(file_name))
|
||||
# continue
|
||||
pred = pred_info.get(file_name, {})
|
||||
acc = cal_acc(pred, answer)
|
||||
acc_info[file_name] = acc
|
||||
if acc < 1.0:
|
||||
error_info[file_name] = {"acc": acc, "pred": pred, "answer": answer}
|
||||
|
||||
error_info = {k: v for k, v in sorted(error_info.items(), key=lambda item: item[1].get("acc", 0))}
|
||||
acc_average = sum(list(acc_info.values())) / (len(acc_info) + 1e-6)
|
||||
return acc_average, error_info
|
||||
|
||||
|
||||
def normalize_values_of_nested_dict(d, normalize_func):
|
||||
"""
|
||||
"""
|
||||
if isinstance(d, dict):
|
||||
return {k: normalize_values_of_nested_dict(v, normalize_func) for k, v in d.items()}
|
||||
elif isinstance(d, list):
|
||||
return [normalize_values_of_nested_dict(x, normalize_func) if isinstance(x, dict) else x for x in d]
|
||||
elif isinstance(d, str):
|
||||
return normalize_func(d)
|
||||
else:
|
||||
return d
|
||||
|
||||
|
||||
def eval_donut(pdt_info, gt_info, normalize_func=None, data_name=None):
|
||||
"""
|
||||
"""
|
||||
if normalize_func is not None:
|
||||
print("--> info: normalize_func executed.")
|
||||
pdt_info = normalize_values_of_nested_dict(pdt_info, normalize_func)
|
||||
gt_info = normalize_values_of_nested_dict(gt_info, normalize_func)
|
||||
|
||||
f1_score, class_eval_info, error_info = cal_f1_all(pdt_info, gt_info)
|
||||
acc_average, acc_error_info = cal_acc_all(pdt_info, gt_info)
|
||||
eval_info = {"f1_score": f1_score, "acc": acc_average, "class_f1_score": class_eval_info,
|
||||
"f1_error_info": error_info, "acc_error_info": acc_error_info}
|
||||
print(data_name, "f1_score", f1_score, "acc", acc_average)
|
||||
return eval_info
|
||||
|
||||
|
||||
def post_process_to_json(qwen_info_str, file_name=None):
|
||||
try:
|
||||
if "```json" in qwen_info_str:
|
||||
if "```" not in qwen_info_str:
|
||||
qwen_info_str += "```"
|
||||
qwen_info_group = re.search(r'```json(.*?)```', qwen_info_str, re.DOTALL)
|
||||
json_str = qwen_info_group.group(1).strip().replace("\n", "")
|
||||
else:
|
||||
json_str = qwen_info_str.strip().replace("\n", "")
|
||||
json_data = json.loads(json_str)
|
||||
return json_data
|
||||
except Exception as err: # noqa: F841
|
||||
return None
|
||||
|
||||
|
||||
def fullwidth_to_halfwidth(text):
|
||||
# Convert full-width characters to half-width
|
||||
result = ''
|
||||
for char in text:
|
||||
code_point = ord(char)
|
||||
# Full-width space is converted directly
|
||||
if code_point == 0x3000:
|
||||
code_point = 0x0020
|
||||
# Other full-width characters (except space) are shifted to half-width
|
||||
elif 0xFF01 <= code_point <= 0xFF5E:
|
||||
code_point -= 0xFEE0
|
||||
result += chr(code_point)
|
||||
result = result.replace("、", ",")
|
||||
return result
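# Usage sketch (hypothetical string): full-width letters/digits are mapped to ASCII and '、' becomes ','.
#   >>> fullwidth_to_halfwidth("ABC123、ok")
#   'ABC123,ok'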
|
||||
|
||||
|
||||
def remove_unnecessary_spaces(text):
|
||||
# Remove spaces between Chinese characters
|
||||
text = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])', '', text)
|
||||
# Remove spaces between Chinese and English/digit characters
|
||||
text = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[a-zA-Z0-9])', '', text)
|
||||
text = re.sub(r'(?<=[a-zA-Z0-9])\s+(?=[\u4e00-\u9fff])', '', text)
|
||||
# Remove unnecessary spaces before punctuation, keep a single space after it
|
||||
text = re.sub(r'(?<![0-9])\s*([,.!?:;])\s*', r'\1 ', text)  # punctuation not preceded by a digit
|
||||
# Add a space between digits and English letters
|
||||
text = re.sub(r'(?<=[0-9])(?=[a-zA-Z])', ' ', text)
|
||||
text = re.sub(r'(?<=[a-zA-Z])(?=[0-9])', ' ', text)
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
return text
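# Usage sketch (hypothetical string) of the rules above: spaces around Chinese characters are dropped,
# and a space is inserted between letters and digits.
#   >>> remove_unnecessary_spaces("机器 学习 model2024")
#   '机器学习model 2024'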
|
||||
|
||||
|
||||
class KieEvaluator(BaseMetric):
|
||||
def response_post_func(self, response_text, **kwargs):
|
||||
response_text = post_process_to_json(response_text, file_name=kwargs.get('file_name', None))
|
||||
return response_text
|
||||
|
||||
def normalize_func(self, text, **kwargs):
|
||||
halfwidth_text = fullwidth_to_halfwidth(str(text))
|
||||
cleaned_text = remove_unnecessary_spaces(halfwidth_text)
|
||||
return cleaned_text
|
||||
|
||||
def evaluate(self, response_info, gt_info, **kwargs):
|
||||
"""
|
||||
response_info: dict: {"file_name_1": response, "file_name_2": response}
|
||||
gt_info: dict: {"file_name_1": gt, "file_name_2": gt}
|
||||
kwargs: dataset index config: {'dataset': 'kie_benchmark_POIE', 'group': 'kie', 'op': 'poie', 'num': 250}
|
||||
"""
|
||||
# gt should be a dict for kie task, fix for VLMEvalKit
|
||||
for image_name, label_content in gt_info.items():
|
||||
if isinstance(label_content, str):
|
||||
gt_info[image_name] = json.loads(label_content)
|
||||
|
||||
response_info = normalize_values_of_nested_dict(response_info, self.normalize_func)
|
||||
gt_info = normalize_values_of_nested_dict(gt_info, self.normalize_func)
|
||||
|
||||
f1_score, class_eval_info, error_info = cal_f1_all(response_info, gt_info)
|
||||
acc_average, acc_error_info = cal_acc_all(response_info, gt_info)
|
||||
|
||||
# summary info
|
||||
summary_info = {"f1_score": f1_score, "acc": acc_average}
|
||||
eval_info = {"summary": summary_info, "class_f1_score": class_eval_info,
|
||||
"f1_error_info": error_info, "acc_error_info": acc_error_info}
|
||||
return eval_info
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
@@ -0,0 +1,106 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
from collections import Counter
|
||||
|
||||
# local import
|
||||
from .common import BaseMetric
|
||||
|
||||
|
||||
def token_normalize(token_text, is_lower=False, is_alphanum_only=False):
|
||||
"""
|
||||
"""
|
||||
if is_lower:
|
||||
token_text = token_text.lower()
|
||||
if is_alphanum_only:
|
||||
token_text = re.sub('[^A-Za-z0-9]+', '', token_text)
|
||||
return token_text
|
||||
|
||||
|
||||
def text_normalize_and_tokenize(text, is_keep_blank=True, is_lower=True, is_alphanum_only=False):
|
||||
text = text.replace("\t", " ").replace("\n", " ").replace("###", "").replace("***", "")
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
if not is_keep_blank:
|
||||
text = text.replace(" ", "")
|
||||
text_tokens = text.split(" ") if is_keep_blank else list(text)
|
||||
text_token_normalized = [token_normalize(t, is_lower, is_alphanum_only) for t in text_tokens]
|
||||
text_token_normalized = [x for x in text_token_normalized if len(x) > 0]
|
||||
return text_token_normalized
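# Usage sketch (hypothetical string): with word-level tokens, lowercasing and alphanumeric-only filtering,
# markdown markers are stripped and each token is normalized.
#   >>> text_normalize_and_tokenize("Hello,  World### 2024", is_keep_blank=True, is_lower=True, is_alphanum_only=True)
#   ['hello', 'world', '2024']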
|
||||
|
||||
|
||||
def evaluate_single_sample(gts, preds):
|
||||
right_num = 0
|
||||
gt_counter_info = dict(Counter(gts))
|
||||
pdt_counter_info = dict(Counter(preds))
|
||||
for gt_token, gt_count in gt_counter_info.items():
|
||||
pred_count = pdt_counter_info.get(gt_token, 0)
|
||||
right_num += min(gt_count, pred_count)
|
||||
return right_num
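# Worked example (hypothetical tokens) of the bag-of-tokens matching above:
#   gts = ['the', 'cat', 'the'], preds = ['the', 'the', 'the', 'dog']
#   -> min(2, 3) for 'the' + min(1, 0) for 'cat' = 2 correct tokens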
|
||||
|
||||
|
||||
def calculate_metrics(response_info, gt_info, is_verbose=False):
|
||||
"""
|
||||
"""
|
||||
macro_recall_list, macro_precision_list, macro_f1_list = [], [], []
|
||||
total_gt_num, total_pred_num, total_right_num = 0, 0, 0
|
||||
for file_name, fullbox_gts in gt_info.items():
|
||||
fullbox_preds = response_info.get(file_name, [])
|
||||
right_num = evaluate_single_sample(fullbox_gts, fullbox_preds)
|
||||
total_right_num += right_num
|
||||
total_gt_num += len(fullbox_gts)
|
||||
total_pred_num += len(fullbox_preds)
|
||||
|
||||
macro_recall = right_num / (len(fullbox_gts) + 1e-9)
|
||||
macro_precision = right_num / (len(fullbox_preds) + 1e-9)
|
||||
macro_f1 = 2 * macro_recall * macro_precision / (macro_recall + macro_precision + 1e-9)
|
||||
macro_recall_list.append(macro_recall)
|
||||
macro_precision_list.append(macro_precision)
|
||||
macro_f1_list.append(macro_f1)
|
||||
|
||||
# macro
|
||||
final_macro_recall = sum(macro_recall_list) / (len(macro_recall_list) + 1e-9)
|
||||
final_macro_precision = sum(macro_precision_list) / (len(macro_precision_list) + 1e-9)
|
||||
final_macro_f1 = sum(macro_f1_list) / (len(macro_f1_list) + 1e-9)
|
||||
|
||||
# micro
|
||||
recall_acc = total_right_num / (total_gt_num + 1e-9)
|
||||
preci_acc = total_right_num / (total_pred_num + 1e-9)
|
||||
hmean = 2 * recall_acc * preci_acc / (recall_acc + preci_acc + 1e-9)
|
||||
vbs_eval_result = {
|
||||
'macro_recall': final_macro_recall, 'macro_precision': final_macro_precision, 'macro_f1_score': final_macro_f1,
|
||||
'micro_recall': recall_acc, 'micro_precision': preci_acc, 'micro_f1_score': hmean
|
||||
}
|
||||
eval_result = vbs_eval_result if is_verbose else {'macro_f1_score': final_macro_f1, 'micro_f1_score': hmean}
|
||||
return eval_result
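# Worked example (hypothetical counts) of macro vs. micro aggregation above:
# image A: 4 GT tokens, 4 predicted, 4 correct -> per-image F1 = 1.0
# image B: 2 GT tokens, 4 predicted, 1 correct -> recall 0.5, precision 0.25, F1 = 1/3
# macro F1 = (1.0 + 1/3) / 2 ≈ 0.667; micro F1 = 2 * (5/6) * (5/8) / (5/6 + 5/8) ≈ 0.714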
|
||||
|
||||
|
||||
class OcrEvaluator(BaseMetric):
|
||||
def response_post_func(self, response_text, **kwargs):
|
||||
return response_text
|
||||
|
||||
def evaluate(self, response_info, gt_info, **kwargs):
|
||||
# hard-coded per-dataset evaluation settings
|
||||
dataset_name = kwargs['dataset']
|
||||
is_word_level, is_lower, is_alphanum_only = True, True, False
|
||||
if dataset_name in ["Arabic", "Japanese", "Korean"] or "zh" in dataset_name:
|
||||
is_word_level = False
|
||||
if "multi_scene_ocr" in self.group_name and is_word_level:
|
||||
is_alphanum_only = True
|
||||
eval_config = {"word_level": is_word_level, "alphanum_only": is_alphanum_only, "lowercase": is_lower}
|
||||
|
||||
image_pdt_info, image_gt_info = {}, {}
|
||||
for file_name, gt_src in gt_info.items():
|
||||
pred_src = response_info.get(file_name, "")
|
||||
pdt_token_list = text_normalize_and_tokenize(
|
||||
str(pred_src).strip(), is_word_level, is_lower, is_alphanum_only)
|
||||
gt_token_list = text_normalize_and_tokenize(
|
||||
str(gt_src).strip(), is_word_level, is_lower, is_alphanum_only)
|
||||
image_pdt_info[file_name] = pdt_token_list
|
||||
image_gt_info[file_name] = gt_token_list
|
||||
eval_result = calculate_metrics(image_pdt_info, image_gt_info, is_verbose=False)
|
||||
return {"summary": eval_result, "metric_config": eval_config}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
682
eval_mm/vlmevalkit/vlmeval/dataset/utils/cgbench.py
Normal file
@@ -0,0 +1,682 @@
|
||||
from ...smp import *
|
||||
from .multiple_choice import extract_answer_from_item
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import re
|
||||
|
||||
FAIL_MSG = "Failed to obtain answer via API."
|
||||
|
||||
frame_tmpl = "frame-{}-of-{}.jpg"
|
||||
|
||||
sys_prompt_open_eval_step_1 = (
|
||||
"You will be provided with a question, a model's prediction, and the ground "
|
||||
"truth answer for this question.\n"
|
||||
"Your task is to judge whether the model's prediction is correct based on the "
|
||||
"meaning of the two texts.\n"
|
||||
"In most cases, this can be done by determining if the meaning of the model's "
|
||||
"prediction is consistent with, or contains, the ground truth answer. However, "
|
||||
"in some cases where the two texts differ, it may represent different "
|
||||
"descriptions of the same visual scene, in which case visual information is "
|
||||
"needed for further judgment.\n"
|
||||
"Therefore, I hope you:\n"
|
||||
"- Output 0, if the model's prediction and the ground truth answer are neither "
|
||||
"consistent nor related by inclusion, with fundamentally different meanings.\n"
|
||||
"- Output 1, if the meaning of the model's prediction and the ground truth "
|
||||
"answer is consistent, or if the model's prediction meaningfully contains the "
|
||||
"ground truth answer.\n"
|
||||
"- Output 2, if the model's prediction and ground truth are not consistent or "
|
||||
"inclusive, but may be different descriptions of the same visual scene, "
|
||||
"requiring visual information for further judgment.\n"
|
||||
"Only output the answer in the following format:\n\n"
|
||||
'```json\n{"result": choice}\n```\n\n'
|
||||
"The choice is either 0, 1, or 2 as specified above."
|
||||
)
|
||||
|
||||
sys_prompt_open_eval_step_2 = (
|
||||
"You will be provided with a question, a model's prediction, and the sampling "
|
||||
"frames of the clue intervals related to this question.\n"
|
||||
"Your task is to determine whether the model has answered the question "
|
||||
"correctly based on the visual information provided.\n"
|
||||
"Therefore, I hope you:\n"
|
||||
"- Output 0, if the model's prediction does not correctly answer the question.\n"
|
||||
"- Output 1, if the model's prediction correctly answers the question.\n"
|
||||
"Only output the answer in the following format without output extra "
|
||||
"explanation:\n\n"
|
||||
'```json\n{"result": choice}\n```\n\n'
|
||||
"The choice is either 0 or 1 as specified above."
|
||||
)
|
||||
|
||||
FAIL_MSG = "Failed to obtain answer via API."
|
||||
|
||||
# '10-20', '20-30', '30-40', '40-50', '50-60'
|
||||
DURATIONS = ["0 ~ 10", "10 ~ 20", "20 ~ 30", "30 ~ 40", "40 ~ 50", "50 ~ 60", "60+"]
|
||||
|
||||
DOMAINS = [
|
||||
"Life Record",
|
||||
"Music & TV show",
|
||||
"Instruction & Knowledge",
|
||||
"Driving",
|
||||
"Embodied Expert",
|
||||
"Humor/funny",
|
||||
"Electonic/Social Gaming",
|
||||
"Security & Health",
|
||||
"Sports & Exercise",
|
||||
"Special Scenes",
|
||||
"Art & Culture",
|
||||
"GUI",
|
||||
"News",
|
||||
"Animal & Pet",
|
||||
]
|
||||
|
||||
SUB_CATEGORIES = [
|
||||
"Time Cognition",
|
||||
"Hallucination",
|
||||
"Entity Perception",
|
||||
"2D Spatial Perception",
|
||||
"Time Perception",
|
||||
"Scene Perception",
|
||||
"Text Perception",
|
||||
"Event Cognition",
|
||||
"Entity Cognition",
|
||||
"Text Cognition",
|
||||
"Event Perception",
|
||||
"Scene Cognition",
|
||||
]
|
||||
|
||||
|
||||
def get_dimention_rating_open_ended(data_path):
|
||||
# Load the data
|
||||
df = load(data_path)
|
||||
|
||||
df = df[df["score"] != -1]
|
||||
|
||||
# Convert seconds to minutes and bin into duration ranges
|
||||
df["duration_minutes"] = df["duration"] / 60
|
||||
df["duration_range"] = pd.cut(
|
||||
df["duration_minutes"], bins=[-np.inf, 10, 20, 30, 40, 50, 60, np.inf], labels=DURATIONS
|
||||
)
|
||||
|
||||
# Initialize the result dict
|
||||
result = {
|
||||
"overall": 0,
|
||||
"duration": {k: 0 for k in DURATIONS},
|
||||
"domain": {k: 0 for k in DOMAINS},
|
||||
"sub_category": {k: 0 for k in SUB_CATEGORIES},
|
||||
}
|
||||
|
||||
# Overall
|
||||
result["overall"] = round(df["score"].mean(), 4)
|
||||
|
||||
# Duration
|
||||
for dur in DURATIONS:
|
||||
dur_scores = df[df["duration_range"] == dur]["score"]
|
||||
result["duration"][dur] = round(dur_scores.mean(), 4) if not dur_scores.empty else 0
|
||||
|
||||
# Domain
|
||||
for domain in DOMAINS:
|
||||
domain_scores = df[df["domain"] == domain]["score"]
|
||||
result["domain"][domain] = round(domain_scores.mean(), 4) if not domain_scores.empty else 0
|
||||
|
||||
# Sub-category
|
||||
for sub_cat in SUB_CATEGORIES:
|
||||
sub_cat_scores = df[df["sub_category"] == sub_cat]["score"]
|
||||
result["sub_category"][sub_cat] = round(sub_cat_scores.mean(), 4) if not sub_cat_scores.empty else 0
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_dimention_rating_mcq_grouding(data_path):
|
||||
|
||||
# Load the data
|
||||
df = load(data_path)
|
||||
|
||||
# df.loc[(df['task_mode'] == 'miou') & (df['score'] == -1), 'score'] = 0
|
||||
|
||||
df = df[df["score"] != -1]
|
||||
|
||||
# Convert seconds to minutes and bin into duration ranges
|
||||
df["duration_minutes"] = df["duration"] / 60
|
||||
df["duration_range"] = pd.cut(
|
||||
df["duration_minutes"], bins=[-np.inf, 10, 20, 30, 40, 50, 60, np.inf], labels=DURATIONS
|
||||
)
|
||||
|
||||
# Initialize the result dict
|
||||
result = {
|
||||
metric: {
|
||||
"overall": 0,
|
||||
"duration": {k: 0 for k in DURATIONS},
|
||||
"domain": {k: 0 for k in DOMAINS},
|
||||
"sub_category": {k: 0 for k in SUB_CATEGORIES},
|
||||
}
|
||||
for metric in ["long_acc", "clue_acc", "miou", "CRR", "acc@iou", "rec@iou"]
|
||||
}
|
||||
|
||||
# Compute the base metrics
|
||||
for metric in ["long_acc", "clue_acc", "miou"]:
|
||||
metric_df = df[df["task_mode"] == metric]
|
||||
|
||||
# Overall
|
||||
result[metric]["overall"] = round(metric_df["score"].mean(), 4)
|
||||
|
||||
# Duration
|
||||
for dur in DURATIONS:
|
||||
dur_scores = metric_df[metric_df["duration_range"] == dur]["score"]
|
||||
result[metric]["duration"][dur] = round(dur_scores.mean(), 4) if not dur_scores.empty else 0
|
||||
|
||||
# Domain
|
||||
for domain in DOMAINS:
|
||||
domain_scores = metric_df[metric_df["domain"] == domain]["score"]
|
||||
result[metric]["domain"][domain] = round(domain_scores.mean(), 4) if not domain_scores.empty else 0
|
||||
|
||||
# Sub-category
|
||||
for sub_cat in SUB_CATEGORIES:
|
||||
sub_cat_scores = metric_df[metric_df["sub_category"] == sub_cat]["score"]
|
||||
result[metric]["sub_category"][sub_cat] = round(sub_cat_scores.mean(), 4) if not sub_cat_scores.empty else 0
|
||||
|
||||
# Compute the composite metric CRR
|
||||
def calculate_crr(scores):
|
||||
long_acc = scores[scores["task_mode"] == "long_acc"]["score"].mean()
|
||||
clue_acc = scores[scores["task_mode"] == "clue_acc"]["score"].mean()
|
||||
return round(min(long_acc, clue_acc) / clue_acc, 4) if clue_acc != 0 else 0
|
||||
|
||||
# Overall CRR
|
||||
result["CRR"]["overall"] = calculate_crr(df)
|
||||
|
||||
# Duration CRR
|
||||
for dur in DURATIONS:
|
||||
dur_df = df[df["duration_range"] == dur]
|
||||
result["CRR"]["duration"][dur] = calculate_crr(dur_df)
|
||||
|
||||
# Domain CRR
|
||||
for domain in DOMAINS:
|
||||
domain_df = df[df["domain"] == domain]
|
||||
result["CRR"]["domain"][domain] = calculate_crr(domain_df)
|
||||
|
||||
# Sub-category CRR
|
||||
for sub_cat in SUB_CATEGORIES:
|
||||
sub_cat_df = df[df["sub_category"] == sub_cat]
|
||||
result["CRR"]["sub_category"][sub_cat] = calculate_crr(sub_cat_df)
|
||||
|
||||
# Compute acc@iou
|
||||
def calculate_acc_at_iou_threshold(scores, threshold):
|
||||
|
||||
miou_qids = set(scores[scores["task_mode"] == "miou"]["qid"])
|
||||
|
||||
long_acc_qids = set(scores[scores["task_mode"] == "long_acc"]["qid"])
|
||||
|
||||
valid_qids = miou_qids & long_acc_qids
|
||||
|
||||
miou_positive = set(scores[(scores["task_mode"] == "miou") & (scores["score"] > threshold)]["qid"])
|
||||
|
||||
long_acc_positive = scores[
|
||||
(scores["task_mode"] == "long_acc") & (scores["qid"].isin(miou_positive)) & (scores["score"] == 1)
|
||||
]
|
||||
|
||||
acc_at_iou_threshold = len(long_acc_positive) / len(valid_qids) if len(valid_qids) > 0 else 0
|
||||
return round(acc_at_iou_threshold, 4)
|
||||
|
||||
def calculate_acc_at_iou(scores):
|
||||
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
acc_at_iou_values = [calculate_acc_at_iou_threshold(scores, threshold) for threshold in thresholds]
|
||||
|
||||
return round(sum(acc_at_iou_values) / len(acc_at_iou_values), 4)
|
||||
|
||||
# Overall acc@iou
|
||||
result["acc@iou"]["overall"] = calculate_acc_at_iou(df)
|
||||
|
||||
# Duration acc@iou
|
||||
for dur in DURATIONS:
|
||||
dur_df = df[df["duration_range"] == dur]
|
||||
result["acc@iou"]["duration"][dur] = calculate_acc_at_iou(dur_df)
|
||||
|
||||
# Domain acc@iou
|
||||
for domain in DOMAINS:
|
||||
domain_df = df[df["domain"] == domain]
|
||||
result["acc@iou"]["domain"][domain] = calculate_acc_at_iou(domain_df)
|
||||
|
||||
# Sub-category acc@iou
|
||||
for sub_cat in SUB_CATEGORIES:
|
||||
sub_cat_df = df[df["sub_category"] == sub_cat]
|
||||
result["acc@iou"]["sub_category"][sub_cat] = calculate_acc_at_iou(sub_cat_df)
|
||||
|
||||
# Compute rec@iou
|
||||
def calculate_rec_at_iou_threshold(scores, threshold):
|
||||
# Get all rows of the miou task
|
||||
miou_scores = scores[scores["task_mode"] == "miou"]
|
||||
|
||||
# Count miou scores above the threshold
|
||||
miou_positive = miou_scores[miou_scores["score"] > threshold]
|
||||
|
||||
# Compute the ratio
|
||||
rec_at_iou = len(miou_positive) / len(miou_scores) if len(miou_scores) > 0 else 0
|
||||
|
||||
return round(rec_at_iou, 4)
|
||||
|
||||
def calculate_rec_at_iou(scores):
|
||||
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
rec_at_iou_values = [calculate_rec_at_iou_threshold(scores, threshold) for threshold in thresholds]
|
||||
|
||||
return round(sum(rec_at_iou_values) / len(rec_at_iou_values), 4)
|
||||
|
||||
# Overall rec@iou
|
||||
result["rec@iou"]["overall"] = calculate_rec_at_iou(df)
|
||||
|
||||
# Duration rec@iou
|
||||
for dur in DURATIONS:
|
||||
dur_df = df[df["duration_range"] == dur]
|
||||
result["rec@iou"]["duration"][dur] = calculate_rec_at_iou(dur_df)
|
||||
|
||||
# Domain rec@iou
|
||||
for domain in DOMAINS:
|
||||
domain_df = df[df["domain"] == domain]
|
||||
result["rec@iou"]["domain"][domain] = calculate_rec_at_iou(domain_df)
|
||||
|
||||
# Sub-category rec@iou
|
||||
for sub_cat in SUB_CATEGORIES:
|
||||
sub_cat_df = df[df["sub_category"] == sub_cat]
|
||||
result["rec@iou"]["sub_category"][sub_cat] = calculate_rec_at_iou(sub_cat_df)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def milliseconds_to_seconds(milliseconds):
|
||||
return milliseconds / 1000
|
||||
|
||||
|
||||
def sample_frames_clue_average(clues_time_intervals, frame_num, fps):
|
||||
# Compute the duration (in frames) of each clue interval
|
||||
clues_frame_intervals = [(round(interval[0] * fps), round(interval[1] * fps)) for interval in clues_time_intervals]
|
||||
clue_durations = [interval[1] - interval[0] for interval in clues_frame_intervals]
|
||||
total_duration = sum(clue_durations)
|
||||
# If frame_num is no less than the total frame count, return all frames directly
|
||||
if frame_num >= total_duration:
|
||||
return [frame for interval in clues_frame_intervals for frame in range(interval[0], interval[1])]
|
||||
frames_per_clue = [int(frame_num * (duration / total_duration)) for duration in clue_durations]
|
||||
frame_indices = []
|
||||
for i, (interval, num_frames) in enumerate(zip(clues_frame_intervals, frames_per_clue)):
|
||||
num_frames = max(1, num_frames)
|
||||
seg_size = (interval[1] - interval[0]) / num_frames
|
||||
clue_frame_indices = [int(interval[0] + seg_size / 2 + seg_size * idx) for idx in range(num_frames)]
|
||||
frame_indices.extend(clue_frame_indices)
|
||||
return frame_indices
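# Usage sketch (hypothetical values): with fps=1, two 10-second clues and a budget of 4 frames,
# each clue receives 2 frames sampled at the centers of equal segments.
#   >>> sample_frames_clue_average([[0, 10], [20, 30]], frame_num=4, fps=1)
#   [2, 7, 22, 27]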
|
||||
|
||||
|
||||
def merge_intervals(intervals):
|
||||
"""
|
||||
Merge overlapping intervals in a list.
|
||||
Assumes each interval is a list [start, end].
|
||||
"""
|
||||
if not intervals:
|
||||
return []
|
||||
|
||||
# Sort intervals by start time
|
||||
intervals.sort(key=lambda x: x[0])
|
||||
|
||||
merged = [intervals[0]]
|
||||
|
||||
for current in intervals[1:]:
|
||||
last_merged = merged[-1]
|
||||
|
||||
# Check if there is an overlap
|
||||
if current[0] <= last_merged[1]:
|
||||
# Merge the current interval with the last one
|
||||
last_merged[1] = max(last_merged[1], current[1])
|
||||
else:
|
||||
# No overlap, add current interval
|
||||
merged.append(current)
|
||||
|
||||
return merged
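# Usage sketch (hypothetical intervals): overlapping intervals are merged after sorting by start time.
#   >>> merge_intervals([[5, 8], [1, 3], [2, 4]])
#   [[1, 4], [5, 8]]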
|
||||
|
||||
|
||||
def calculate_intervals_iou(intervals1, intervals2):
|
||||
"""
|
||||
Calculate the IoU of two lists of intervals.
|
||||
Each list contains intervals represented as [start, end].
|
||||
"""
|
||||
# Merge overlapping intervals in both lists
|
||||
merged1 = merge_intervals(intervals1)
|
||||
merged2 = merge_intervals(intervals2)
|
||||
|
||||
# Calculate total length of intervals for both lists
|
||||
def total_length(merged_intervals):
|
||||
return sum(end - start for start, end in merged_intervals)
|
||||
|
||||
length1 = total_length(merged1)
|
||||
length2 = total_length(merged2)
|
||||
|
||||
# Calculate intersection length
|
||||
intersection_length = 0
|
||||
for interval1 in merged1:
|
||||
for interval2 in merged2:
|
||||
intersection_start = max(interval1[0], interval2[0])
|
||||
intersection_end = min(interval1[1], interval2[1])
|
||||
intersection_length += max(0, intersection_end - intersection_start)
|
||||
# Calculate union length
|
||||
union_length = length1 + length2 - intersection_length
|
||||
# IoU is intersection divided by union
|
||||
iou = intersection_length / union_length if union_length > 0 else 0
|
||||
return iou
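# Usage sketch (hypothetical intervals): intersection 5s over union 15s.
#   >>> calculate_intervals_iou([[0, 10]], [[5, 15]])
#   0.3333333333333333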
|
||||
|
||||
|
||||
def post_process(response, right_answer, task_mode, duration):
|
||||
result = -1
|
||||
|
||||
if response:
|
||||
# Locate the ```json and ``` markers
|
||||
json_start = response.find("```json")
|
||||
json_end = response.find("```", json_start + len("```json"))
|
||||
|
||||
# If JSON content was found
|
||||
if json_start != -1 and json_end != -1:
|
||||
json_content = response[json_start + len("```json"):json_end].strip()
|
||||
else:
|
||||
json_content = ""
|
||||
|
||||
if json_content:
|
||||
if task_mode in ["long_acc", "clue_acc"]:
|
||||
json_content = re.sub(r"(?<=:\s)([A-Za-z_]\w*)", r'"\1"', json_content)
|
||||
|
||||
try:
|
||||
model_result = json.loads(json_content)["result"]
|
||||
|
||||
if task_mode in ["long_acc", "clue_acc"]:
|
||||
result = 1 if right_answer == model_result else 0
|
||||
elif task_mode == "miou":
|
||||
if not isinstance(model_result, list):
|
||||
return -1
|
||||
if not isinstance(model_result[0], list):
|
||||
model_result = [model_result]
|
||||
|
||||
need_duration = all(interval[0] <= 1 and interval[1] <= 1 for interval in model_result)
|
||||
|
||||
if need_duration:
|
||||
model_result = [[interval[0] * duration, interval[1] * duration] for interval in model_result]
|
||||
|
||||
right_answer = eval(right_answer)
|
||||
|
||||
result = calculate_intervals_iou(right_answer, model_result)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in parsing JSON: {e}, {json_content}")
|
||||
|
||||
if result == -1:
|
||||
if task_mode in ["long_acc", "clue_acc"]:
|
||||
# Check for uppercase letters A-H and treat them as the model's answer
|
||||
matches = re.findall(r"\b[A-H]\b", response)
|
||||
if matches:
|
||||
result = 1 if right_answer in matches else 0
|
||||
elif task_mode == "miou":
|
||||
# Extract all real numbers and pair them into intervals
|
||||
numbers = re.findall(r"-?\d+\.?\d*", response)
|
||||
if len(numbers) < 2:
|
||||
result = -1
|
||||
else:
|
||||
if len(numbers) % 2 != 0:
|
||||
numbers = numbers[:-1]
|
||||
model_result = [[float(numbers[i]), float(numbers[i + 1])] for i in range(0, len(numbers), 2)]
|
||||
|
||||
if type(right_answer) is str:
|
||||
right_answer = eval(right_answer)
|
||||
|
||||
result = calculate_intervals_iou(right_answer, model_result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_timestampes(frame_indices, fps):
|
||||
seconds = list(map(lambda x: str(round(x / fps, 4)), frame_indices))
|
||||
timestamps = ", ".join(seconds)
|
||||
return "A total of {frame_num} frames are sampled. Their corresponding timestamps are:\n\n{timestamps}\n\n".format(
|
||||
frame_num=len(frame_indices), timestamps=timestamps
|
||||
)
|
||||
|
||||
|
||||
def post_process_open(response):
|
||||
model_result = -1
|
||||
|
||||
if response and response != FAIL_MSG:
|
||||
json_start = response.find("```json")
|
||||
json_end = response.find("```", json_start + len("```json"))
|
||||
|
||||
# If JSON content was found
|
||||
if json_start != -1 and json_end != -1:
|
||||
json_content = response[json_start + len("```json"):json_end].strip()
|
||||
else:
|
||||
json_content = ""
|
||||
|
||||
if json_content:
|
||||
try:
|
||||
model_result = json.loads(json_content)["result"]
|
||||
except Exception as e:
|
||||
print(f"Error in parsing JSON: {e}, {json_content}")
|
||||
|
||||
if model_result == -1:
|
||||
model_result = response
|
||||
|
||||
return model_result
|
||||
|
||||
|
||||
def post_process_eval_open(response, step):
|
||||
|
||||
model_result = -1
|
||||
|
||||
if response and response != FAIL_MSG:
|
||||
|
||||
json_start = response.find("```json")
|
||||
json_end = response.find("```", json_start + len("```json"))
|
||||
|
||||
if json_start != -1 and json_end != -1:
|
||||
json_content = response[json_start + len("```json"):json_end].strip()
|
||||
else:
|
||||
json_content = ""
|
||||
|
||||
if json_content:
|
||||
try:
|
||||
model_result = json.loads(json_content)["result"]
|
||||
except Exception as e:
|
||||
print(f"Error in parsing JSON: {e}, {json_content}")
|
||||
return -1
|
||||
if model_result == -1:
|
||||
if step == 1:
|
||||
match = re.search(r"[012]", response)
|
||||
if match:
|
||||
model_result = int(match.group())
|
||||
else:
|
||||
match = re.search(r"[01]", response)
|
||||
if match:
|
||||
model_result = int(match.group())
|
||||
|
||||
return model_result
|
||||
|
||||
|
||||
def eval_open_first(model, line):
|
||||
|
||||
user_prompt = ""
|
||||
|
||||
user_prompt += f"Question: {line['question']}\n\n"
|
||||
|
||||
user_prompt += f"The ground truth answer is '{line['answer']}'\n\n"
|
||||
|
||||
user_prompt += f"The model's prediction is '{line['model_result']}'\n\n"
|
||||
|
||||
result = model.generate(user_prompt)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def save_step_1_steps(data, step_1_results):
|
||||
|
||||
# Process all results
|
||||
data["step_1_result"] = data["qid"].map(lambda x: post_process_eval_open(step_1_results[x], 1))
|
||||
|
||||
# Conditionally update
|
||||
mask = data["step_1_result"].isin([-1, 0, 1])
|
||||
data.loc[mask, "step_2_result"] = data.loc[mask, "step_1_result"]
|
||||
data.loc[mask, "score"] = data.loc[mask, "step_1_result"]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def eval_open_second(model, line, frame_paths):
|
||||
|
||||
user_prompt = ""
|
||||
|
||||
user_prompt += f"Question: {line['question']}\n\n"
|
||||
|
||||
user_prompt += f"The model's prediction is '{line['model_result']}'\n\n"
|
||||
|
||||
result = model.generate([user_prompt] + frame_paths)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def save_step_2_steps(data, step_1_results):
|
||||
|
||||
# Process all results
|
||||
data["score"] = data["qid"].map(lambda x: post_process_eval_open(step_1_results[x], 2))
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def clue_frame_paths(clue_frame_root, qid, num_frames=8):
|
||||
frame_root = osp.join(clue_frame_root, str(qid))
|
||||
os.makedirs(frame_root, exist_ok=True)
|
||||
return [osp.join(frame_root, frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
|
||||
|
||||
|
||||
def save_clue_video_frames(data_root, clue_frame_root, video, uid, clue_intervals=None, num_frames=8, fps=-1):
|
||||
|
||||
if not isinstance(uid, str):
|
||||
uid = str(uid)
|
||||
|
||||
vid_path = osp.join(data_root, video)
|
||||
vid = decord.VideoReader(vid_path)
|
||||
vid_fps = vid.get_avg_fps()
|
||||
|
||||
if clue_intervals is not None:
|
||||
# 1. Merge overlapping intervals
|
||||
merged_intervals = merge_intervals(clue_intervals)
|
||||
|
||||
if num_frames > 0 and fps < 0:
|
||||
# 2. Sample frames uniformly within clue_intervals
|
||||
indices = sample_frames_clue_average(merged_intervals, num_frames, vid_fps)
|
||||
frame_paths = clue_frame_paths(clue_frame_root, uid, len(indices))
|
||||
|
||||
# Save the frames
|
||||
flag = np.all([osp.exists(p) for p in frame_paths])
|
||||
if not flag:
|
||||
images = [vid[i].asnumpy() for i in indices]
|
||||
images = [Image.fromarray(arr) for arr in images]
|
||||
for im, pth in zip(images, frame_paths):
|
||||
if not osp.exists(pth):
|
||||
im.save(pth)
|
||||
|
||||
return frame_paths, indices, vid_fps
|
||||
|
||||
|
||||
def get_chunk_number(filename):
|
||||
try:
|
||||
num = filename.split("chunk_")[1].split(".zip")[0]
|
||||
return int(num)
|
||||
except (IndexError, ValueError):
|
||||
return float('inf')
|
||||
|
||||
|
||||
def unzip_hf_zip(pth):
|
||||
|
||||
import zipfile
|
||||
|
||||
target_dir = pth
|
||||
|
||||
if os.path.exists(f"{target_dir}/cg_videos_720p") and os.path.exists(f"{target_dir}/cg_subtitles")\
|
||||
and os.path.exists(f"{target_dir}/cg_clue_videos"):
|
||||
print("all exists")
|
||||
return
|
||||
|
||||
video_zip_files = [
|
||||
os.path.join(target_dir, file)
|
||||
for file in os.listdir(target_dir)
|
||||
if file.endswith(".zip") and file.startswith("video")
|
||||
]
|
||||
|
||||
video_zip_files = sorted(video_zip_files, key=lambda x: get_chunk_number(os.path.basename(x)))
|
||||
|
||||
videos_temp_zip = os.path.join(target_dir, "videos_merged.zip")
|
||||
|
||||
print("Merging video files ...")
|
||||
|
||||
with open(videos_temp_zip, "wb") as outfile:
|
||||
for video_zip_file in tqdm(video_zip_files, desc="Merging videos"):
|
||||
with open(video_zip_file, "rb") as infile:
|
||||
outfile.write(infile.read())
|
||||
|
||||
print("Extracting video files...")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(videos_temp_zip, "r") as zip_ref:
|
||||
|
||||
total_files = len(zip_ref.namelist())
|
||||
|
||||
for file in tqdm(zip_ref.namelist(), desc="Extracting", total=total_files):
|
||||
zip_ref.extract(file, target_dir)
|
||||
|
||||
print(f"Successfully extracted to {target_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error during extraction: {e}")
|
||||
finally:
|
||||
|
||||
if os.path.exists(videos_temp_zip):
|
||||
os.remove(videos_temp_zip)
|
||||
print("Cleaned up temporary video file")
|
||||
|
||||
clue_video_zip_files = [
|
||||
os.path.join(target_dir, file)
|
||||
for file in os.listdir(target_dir)
|
||||
if file.endswith(".zip") and file.startswith("clue_video")
|
||||
]
|
||||
|
||||
clue_video_zip_files = sorted(clue_video_zip_files, key=lambda x: get_chunk_number(os.path.basename(x)))
|
||||
|
||||
clue_videos_temp_zip = os.path.join(target_dir, "clue_videos_merged.zip")
|
||||
|
||||
print("Merging clue video files ...")
|
||||
|
||||
with open(clue_videos_temp_zip, "wb") as outfile:
|
||||
for clue_video_zip_file in tqdm(clue_video_zip_files, desc="Merging clue_videos"):
|
||||
with open(clue_video_zip_file, "rb") as infile:
|
||||
outfile.write(infile.read())
|
||||
|
||||
print("Extracting clue video files...")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(clue_videos_temp_zip, "r") as zip_ref:
|
||||
|
||||
total_files = len(zip_ref.namelist())
|
||||
|
||||
for file in tqdm(zip_ref.namelist(), desc="Extracting", total=total_files):
|
||||
zip_ref.extract(file, target_dir)
|
||||
|
||||
print(f"Successfully extracted to {target_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error during extraction: {e}")
|
||||
finally:
|
||||
|
||||
if os.path.exists(clue_videos_temp_zip):
|
||||
os.remove(clue_videos_temp_zip)
|
||||
print("Cleaned up temporary clue video file")
|
||||
|
||||
print("Extracting subtitle files ...")
|
||||
|
||||
subtitles_zip = os.path.join(target_dir, "subtitles.zip")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(subtitles_zip, "r") as zip_ref:
|
||||
|
||||
total_files = len(zip_ref.namelist())
|
||||
|
||||
for file in tqdm(zip_ref.namelist(), desc="Extracting", total=total_files):
|
||||
zip_ref.extract(file, target_dir)
|
||||
|
||||
print(f"Successfully extracted to {target_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error during extraction: {e}")
|
||||
13
eval_mm/vlmevalkit/vlmeval/dataset/utils/crpe.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import json
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def is_correct(predict, answer):
|
||||
# predict is the ground-truth answer, answer is the prediction
|
||||
if len(answer) == 1:
|
||||
return answer[0] == predict[0]
|
||||
elif len(answer) != 1 and answer[0] in ['A', 'B', 'C', 'D']:
|
||||
return answer[0] == predict[0]
|
||||
elif len(answer) != 1 and answer[0] not in ['A', 'B', 'C', 'D']:
|
||||
return predict[4:].lower() in answer.lower()
|
||||
54
eval_mm/vlmevalkit/vlmeval/dataset/utils/hrbench.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from ...smp import *
|
||||
import os
|
||||
|
||||
|
||||
def report_acc_hrbench(df):
|
||||
cycle_group = df.groupby('cycle_category')
|
||||
result_dic = defaultdict(list)
|
||||
avg_dic = defaultdict(int)
|
||||
|
||||
count = 0
|
||||
for key, data_value in cycle_group:
|
||||
count += 1
|
||||
_, resp_dic = hrbench_score(data_value)
|
||||
|
||||
for task_type, accuracy in resp_dic.items():
|
||||
result_dic['cycle'].append(key)
|
||||
result_dic['type'].append(task_type)
|
||||
result_dic['accuracy'].append(accuracy)
|
||||
|
||||
avg_dic[task_type] += accuracy
|
||||
for task_type, accuracy in avg_dic.items():
|
||||
result_dic['cycle'].append('Average')
|
||||
result_dic['type'].append(task_type)
|
||||
result_dic['accuracy'].append(accuracy / count)
|
||||
result_pd = pd.DataFrame(result_dic)
|
||||
|
||||
return result_pd
|
||||
|
||||
|
||||
def hrbench_score(data):
|
||||
ret = defaultdict(list)
|
||||
resp_dic = {}
|
||||
category_list = set(data['category'])
|
||||
score_dict = defaultdict(list)
|
||||
|
||||
for i in range(len(data)):
|
||||
d = data.iloc[i]
|
||||
category = d['category']
|
||||
gpt_score = d['hit']
|
||||
score_dict[category].append(gpt_score)
|
||||
score_dict['all'].append(gpt_score)
|
||||
|
||||
all_acc = np.mean(score_dict['all'])
|
||||
ret['type'].append('all')
|
||||
ret['acc'].append(all_acc)
|
||||
resp_dic['all'] = all_acc
|
||||
for cate in category_list:
|
||||
acc = np.mean(score_dict[cate])
|
||||
ret['type'].append(cate)
|
||||
ret['acc'].append(acc)
|
||||
|
||||
resp_dic[cate] = acc
|
||||
|
||||
return pd.DataFrame(ret), resp_dic
|
||||
49
eval_mm/vlmevalkit/vlmeval/dataset/utils/judge_util.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
from ...smp import load_env
|
||||
|
||||
INTERNAL = os.environ.get('INTERNAL', 0)
|
||||
|
||||
|
||||
def build_judge(**kwargs):
|
||||
from ...api import OpenAIWrapper, SiliconFlowAPI
|
||||
model = kwargs.pop('model', None)
|
||||
kwargs.pop('nproc', None)
|
||||
load_env()
|
||||
LOCAL_LLM = os.environ.get('LOCAL_LLM', None)
|
||||
if LOCAL_LLM is None:
|
||||
model_map = {
|
||||
'gpt-4-turbo': 'gpt-4-1106-preview',
|
||||
'gpt-4-0613': 'gpt-4-0613',
|
||||
'gpt-4-0125': 'gpt-4-0125-preview',
|
||||
'gpt-4-0409': 'gpt-4-turbo-2024-04-09',
|
||||
'chatgpt-1106': 'gpt-3.5-turbo-1106',
|
||||
'chatgpt-0125': 'gpt-3.5-turbo-0125',
|
||||
'gpt-4o': 'gpt-4o-2024-05-13',
|
||||
'gpt-4o-0806': 'gpt-4o-2024-08-06',
|
||||
'gpt-4o-mini': 'gpt-4o-mini-2024-07-18',
|
||||
'qwen-7b': 'Qwen/Qwen2.5-7B-Instruct',
|
||||
'qwen-72b': 'Qwen/Qwen2.5-72B-Instruct',
|
||||
'deepseek': 'deepseek-ai/DeepSeek-V2.5',
|
||||
}
|
||||
model_version = model_map[model]
|
||||
else:
|
||||
model_version = LOCAL_LLM
|
||||
|
||||
if model in ['qwen-7b', 'qwen-72b', 'deepseek']:
|
||||
model = SiliconFlowAPI(model_version, **kwargs)
|
||||
else:
|
||||
model = OpenAIWrapper(model_version, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
DEBUG_MESSAGE = """
|
||||
To debug the OpenAI API, you can try the following scripts in python:
|
||||
```python
|
||||
from vlmeval.api import OpenAIWrapper
|
||||
model = OpenAIWrapper('gpt-4o', verbose=True)
|
||||
msgs = [dict(type='text', value='Hello!')]
|
||||
code, answer, resp = model.generate_inner(msgs)
|
||||
print(code, answer, resp)
|
||||
```
|
||||
You can see the specific error if the API call fails.
|
||||
"""
|
||||
@@ -1,10 +1,6 @@
|
||||
import argparse
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import os.path as osp
|
||||
from vlmeval.evaluate.misc import build_judge
|
||||
from vlmeval.smp import *
|
||||
from vlmeval.utils import track_progress_rich
|
||||
from ...smp import *
|
||||
|
||||
rule_dict = {
|
||||
'llava_bench_conv': {'role': 'Assistant', 'prompt': 'We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.'}, # noqa: E501
|
||||
@@ -67,54 +63,3 @@ def LLaVABench_score(data):
|
||||
ret['VLM Score'].append(np.mean(sub['score']) * 10)
|
||||
ret['GPT4 Score'].append(np.mean(sub['gpt4_score']) * 10)
|
||||
return pd.DataFrame(ret)
|
||||
|
||||
|
||||
def LLaVABench_eval(eval_file, **judge_kwargs):
|
||||
suffix = '.' + eval_file.split('.')[-1]
|
||||
record_file = eval_file.replace(suffix, '_openai_result' + suffix)
|
||||
score_file = eval_file.replace(suffix, '_score.csv')
|
||||
nproc = judge_kwargs.pop('nproc', 4)
|
||||
|
||||
if not osp.exists(record_file):
|
||||
data = load(eval_file)
|
||||
lines = [data.iloc[i] for i in range(len(data))]
|
||||
model = build_judge(
|
||||
temperature=0.2,
|
||||
system_prompt='You are a helpful and precise assistant for checking the quality of the answer.',
|
||||
**judge_kwargs)
|
||||
prompts = [build_prompt(line) for line in lines]
|
||||
tups = [(model, prompt) for prompt in prompts]
|
||||
scores = track_progress_rich(LLaVABench_atomeval, tups, nproc=nproc, chunksize=nproc)
|
||||
data['gpt4_score'] = [x[0] for x in scores]
|
||||
data['score'] = [x[1] for x in scores]
|
||||
dump(data, record_file)
|
||||
|
||||
data = load(record_file)
|
||||
ret = LLaVABench_score(data).round(1)
|
||||
print(ret)
|
||||
dump(ret, score_file)
|
||||
return ret
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='LLaVABench Evaluation. ')
|
||||
parser.add_argument('data', type=str, help='The question set for inference, in excel / tsv / json format. ')
|
||||
parser.add_argument(
|
||||
'--model', type=str, help='The LLM (GPT) used for inference. ', default='gpt-4-turbo',
|
||||
choices=['gpt-4-0613', 'gpt-4-turbo', 'chatgpt-1106', 'chatgpt-0613', 'gpt-4-0314'])
|
||||
parser.add_argument('--nproc', type=int, default=4)
|
||||
parser.add_argument('--verbose', action='store_true')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
load_env()
|
||||
args = parse_args()
|
||||
judge_kwargs = dict(model=args.model, nproc=args.nproc, verbose=args.verbose)
|
||||
if 'OPENAI_API_KEY_JUDGE' in os.environ and os.environ['OPENAI_API_KEY_JUDGE']:
|
||||
judge_kwargs['key'] = os.environ['OPENAI_API_KEY_JUDGE']
|
||||
if 'OPENAI_API_BASE_JUDGE' in os.environ and os.environ['OPENAI_API_BASE_JUDGE']:
|
||||
judge_kwargs['api_base'] = os.environ['OPENAI_API_BASE_JUDGE']
|
||||
|
||||
LLaVABench_eval(eval_file=args.data, **judge_kwargs)
|
||||
150
eval_mm/vlmevalkit/vlmeval/dataset/utils/logicvista.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import pandas as pd
|
||||
|
||||
# from colorama import Fore, Back, Style
|
||||
from ...smp import *
|
||||
|
||||
|
||||
FAIL_MSG = 'Failed to obtain answer via API.'
|
||||
|
||||
|
||||
def build_prompt_logicvista(line):
|
||||
question = line['question']
|
||||
prediction = str(line['prediction'])
|
||||
tmpl = (
|
||||
"You are a information extractor that extracts multiple choice letter answer choices "
|
||||
"from a paragraph that contains the answer choice and sometimes explaination of why that "
|
||||
"choice is correct to the given question.\n"
|
||||
"What letter did the following answer choose? If the answer did not select a letter answer choice, "
|
||||
"first try to infer the answer based off the given choices.\n"
|
||||
"If it does not seem like the given answer corresponds to an answer choice OR if there is no selected answer, please just respond with Z.\n"
|
||||
"Make sure you answer with ONLY the letters chosen.\n"
|
||||
'Example 1: \n'
|
||||
'Question: <start>\nWhat is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n<end>\n'
|
||||
'Answer: <start>\na cute teddy bear\n<end>\nYour output: A\n'
|
||||
'Example 2: \n'
|
||||
'Question: <start>\nWhat is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n<end>\n'
|
||||
'Answer: <start>\nSpider\n<end>\nYour output: Z\n'
|
||||
'Example 3: \n'
|
||||
'Question: <start>\nWhich figure is a rotation of the object?\n<end>\n'
|
||||
'Answer: <start>\nThe figure on the right, labeled "D," is a rotation of the object shown in the top left corner.\n<end>\nYour output: D\n'
|
||||
'Example 4: \n'
|
||||
'Question: <start>\nWhich of the boxes comes next in the sequence? Select from A-E\n<end>\n'
|
||||
'Answer: <start>\nThe sequence of the boxes is A, B, C, D, E.\n<end>\nYour output: ABCDE\n'
|
||||
'Example 5: \n'
|
||||
'Question: <start>\n{}\n<end>\nAnswer: <start>\n{}\n<end>\nYour output: '
|
||||
)
|
||||
|
||||
return tmpl.format(question, prediction)
|
||||
|
||||
|
||||
def LogicVista_auxeval(model, line):
|
||||
prompt = build_prompt_logicvista(line)
|
||||
print(prompt)
|
||||
log = ''
|
||||
retry = 5
|
||||
|
||||
for i in range(retry):
|
||||
prediction = line['prediction']
|
||||
res = model.generate(prompt, temperature=i * 0.5)
|
||||
answer = line['answer'].split(", ")
|
||||
for j in range(0, len(answer)):
|
||||
answer[j] = answer[j].lower()
|
||||
answer.sort()
|
||||
answer = ''.join(answer)
|
||||
|
||||
if FAIL_MSG in res:
|
||||
log += f'Try {i}: output is {prediction}, failed to parse.\n'
|
||||
elif not res.isupper() or not res.isalpha():
|
||||
log += f'Try {i}: output is {prediction}, failed to parse.\n'
|
||||
else:
|
||||
log += 'Succeed'
|
||||
hit = 0
|
||||
extracted = [alpha.lower() for alpha in res]
|
||||
extracted.sort()
|
||||
extracted = ''.join(extracted)
|
||||
if extracted == answer:
|
||||
hit = 1
|
||||
return dict(log=log, res=res, hit=hit)
|
||||
log += 'All 5 retries failed.\n'
|
||||
return dict(log=log, res='', hit=0)
|
||||
|
||||
|
||||
cat = ["diagram", "ocr", "patterns", "graphs", "tables", "3d shapes", "puzzles", "sequences", "physics"]
|
||||
|
||||
|
||||
def evaluate_logicvista(file_path):
|
||||
df = pd.read_excel(file_path)
|
||||
|
||||
tot = defaultdict(lambda: 0)
|
||||
hit = defaultdict(lambda: 0)
|
||||
acc = defaultdict(lambda: 0)
|
||||
|
||||
lt = len(df)
|
||||
skill_list = []
|
||||
|
||||
df_tot = df
|
||||
|
||||
df_inductive = df[df["skill"].str.contains("inductive")]
|
||||
df_deductive = df[df["skill"].str.contains("deductive")]
|
||||
df_numerical = df[df["skill"].str.contains("numerical")]
|
||||
df_spatial = df[df["skill"].str.contains("spatial")]
|
||||
df_mechanical = df[df["skill"].str.contains("mechanical")]
|
||||
|
||||
tot_correct = df_tot["hit"].sum()
|
||||
tot_acc = (tot_correct / df_tot.shape[0]) * 100
|
||||
tot['Overall'] = df_tot.shape[0]
|
||||
hit['Overall'] = tot_correct
|
||||
acc['Overall'] = tot_acc
|
||||
|
||||
inductive_correct = df_inductive["hit"].sum()
|
||||
inductive_acc = (inductive_correct / df_inductive.shape[0]) * 100
|
||||
|
||||
tot["inductive"] = df_inductive.shape[0]
|
||||
hit["inductive"] = inductive_correct
|
||||
acc["inductive"] = inductive_acc
|
||||
|
||||
deductive_correct = df_deductive["hit"].sum()
|
||||
deductive_acc = (deductive_correct / df_deductive.shape[0]) * 100
|
||||
|
||||
tot["deductive"] = df_deductive.shape[0]
|
||||
hit["deductive"] = deductive_correct
|
||||
acc["deductive"] = deductive_acc
|
||||
|
||||
numerical_correct = df_numerical["hit"].sum()
|
||||
numerical_acc = (numerical_correct / df_numerical.shape[0]) * 100
|
||||
|
||||
tot["numerical"] = df_numerical.shape[0]
|
||||
hit["numerical"] = numerical_correct
|
||||
acc["numerical"] = numerical_acc
|
||||
|
||||
spatial_correct = df_spatial["hit"].sum()
|
||||
spatial_acc = (spatial_correct / df_spatial.shape[0]) * 100
|
||||
|
||||
tot["spatial"] = df_spatial.shape[0]
|
||||
hit["spatial"] = spatial_correct
|
||||
acc["spatial"] = spatial_acc
|
||||
|
||||
mechanical_correct = df_mechanical["hit"].sum()
|
||||
mechanical_acc = (mechanical_correct / df_mechanical.shape[0]) * 100
|
||||
|
||||
tot["mechanical"] = df_mechanical.shape[0]
|
||||
hit["mechanical"] = mechanical_correct
|
||||
acc["mechanical"] = mechanical_acc
|
||||
|
||||
# capability dimension, the official data json does not contain 'capability' column, so it is now ignored
|
||||
# for i in cat:
|
||||
# curr = df[df["capability"].str.contains(i.replace(" ", ""))]
|
||||
# correct = curr["hit"].sum()
|
||||
# accuracy = (correct / curr.shape[0]) * 100
|
||||
# tot[i] = curr.shape[0]
|
||||
# hit[i] = correct
|
||||
# acc[i] = accuracy
|
||||
|
||||
res = defaultdict(list)
|
||||
for k in tot.keys():
|
||||
res['Task&Skill'].append(k)
|
||||
res['tot'].append(tot[k])
|
||||
res['hit'].append(hit[k])
|
||||
res['acc'].append(acc[k])
|
||||
res = pd.DataFrame(res)
|
||||
return res
|
||||
80
eval_mm/vlmevalkit/vlmeval/dataset/utils/longvideobench.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from ...smp import *
|
||||
from .multiple_choice import extract_answer_from_item
|
||||
import numpy as np
|
||||
import re
|
||||
|
||||
FAIL_MSG = 'Failed to obtain answer via API.'
|
||||
|
||||
DURATIONS = [15, 60, 600, 3600]
|
||||
TASK_CATEGORIES = [
|
||||
"S2E", "S2O", "S2A",
|
||||
"E2O", "O2E", "T2E",
|
||||
"T2O", "T2A", "E3E",
|
||||
"O3O", "SSS", "SOS",
|
||||
"SAA", "T3E", "T3O",
|
||||
"TOS", "TAA"
|
||||
]
|
||||
|
||||
|
||||
def get_dimension_rating(data_path):
|
||||
data = load(data_path)
|
||||
print(data.iloc[0])
|
||||
|
||||
duration_rating = {k: {} for k in DURATIONS}
|
||||
for duration in DURATIONS + ['overall']:
|
||||
duration_rating[duration] = {
|
||||
'overall': '',
|
||||
'question_category': {k: [] for k in TASK_CATEGORIES}
|
||||
}
|
||||
|
||||
for i in range(len(data)):
|
||||
|
||||
task_ctg = data.iloc[i]['question_category']
|
||||
|
||||
duration = data.iloc[i]['duration_group']
|
||||
duration_rating[duration]['question_category'][task_ctg].append(data.iloc[i]['score'])
|
||||
|
||||
duration_rating['overall']['question_category'][task_ctg].append(data.iloc[i]['score'])
|
||||
|
||||
for duration in DURATIONS + ['overall']:
|
||||
overall_res_dur = f'{np.mean([x for x in sum(duration_rating[duration]["question_category"].values(), []) if x >= 0]):.3f}' # noqa: E501
|
||||
duration_rating[duration]['overall'] = overall_res_dur
|
||||
for task_ctg in TASK_CATEGORIES:
|
||||
task_res_dur = f'{np.mean([x for x in duration_rating[duration]["question_category"][task_ctg] if x >= 0]):.3f}' # noqa: E501
|
||||
duration_rating[duration]['question_category'][task_ctg] = task_res_dur
|
||||
|
||||
return duration_rating
|
||||
|
||||
|
||||
def extract_option(model, input_item, dataset_name):
|
||||
options = input_item['question'].split('\n')[1:]
|
||||
for id, option in enumerate(options):
|
||||
option_id = chr(ord('A') + id) + '.'
|
||||
if option.find(option_id) >= 0:
|
||||
input_item[chr(ord('A') + id)] = option[option.find(option_id) + len(option_id):].strip('. \n')
|
||||
return extract_answer_from_item(model, input_item, dataset_name)['opt']
|
||||
|
||||
|
||||
def extract_characters_regex(s):
|
||||
s = s.strip()
|
||||
answer_prefixes = [
|
||||
'The best answer is',
|
||||
'The correct answer is',
|
||||
'The answer is',
|
||||
'The answer',
|
||||
'The best option is',
|
||||
'The correct option is',
|
||||
'Best answer:',
|
||||
'Best option:',
|
||||
'Answer:',
|
||||
'Option:',
|
||||
]
|
||||
for answer_prefix in answer_prefixes:
|
||||
s = s.replace(answer_prefix, '')
|
||||
|
||||
if len(s.split()) > 10 and not re.search('[ABCDE]', s):
|
||||
return ''
|
||||
matches = re.search(r'[ABCDE]', s)
|
||||
if matches is None:
|
||||
return ''
|
||||
return matches[0]
|
||||