Update to MiniCPM-V 2.6

This commit is contained in:
yiranyyu
2024-08-06 12:26:49 +08:00
parent 1cb882d473
commit b1a15299e6
28 changed files with 3692 additions and 191 deletions

1002  README.md  (file diff suppressed because it is too large)

File diff suppressed because it is too large

File diff suppressed because it is too large

BIN  assets/gif_cases/ai.gif  (new binary file, 5.5 MiB; not shown)

BIN  assets/gif_cases/beer.gif  (new binary file, 4.4 MiB; not shown)

BIN  assets/gif_cases/mb.gif  (new binary file, 13 MiB; not shown)

BIN  assets/gif_cases/rabbit.gif  (new binary file, 7.7 MiB; not shown)

BIN  modified binary file (not shown): 14 MiB before, 10 MiB after

BIN  assets/gif_cases/wfh.gif  (new binary file, 7.8 MiB; not shown)

BIN  assets/gif_cases/zoo.gif  (new binary file, 3.8 MiB; not shown)

BIN  seven more new binary files (not shown): 1.9 MiB, 4.4 MiB, 8.0 MiB, 4.5 MiB, 3.7 MiB, 2.1 MiB, 2.6 MiB

BIN  assets/radar_final.png  (new binary file, 1.1 MiB; not shown)

76  chat.py

@@ -183,13 +183,87 @@ class MiniCPMV2_5:
        )
        return answer


class MiniCPMV2_6:
    def __init__(self, model_path, multi_gpus=False) -> None:
        print('torch_version:', torch.__version__)
        if multi_gpus:  # inference on multiple GPUs
            from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map
            with init_empty_weights():
                model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
                                                  attn_implementation='sdpa', torch_dtype=torch.bfloat16)
            device_map = infer_auto_device_map(model, max_memory={0: "10GB", 1: "10GB"},
                                               no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'])
            device_id = device_map["llm.model.embed_tokens"]
            device_map["llm.lm_head"] = device_id  # the first and last layers of the LLM should be on the same device
            device_map["vpm"] = device_id
            device_map["resampler"] = device_id
            device_id2 = device_map["llm.model.layers.26"]
            device_map["llm.model.layers.8"] = device_id2
            device_map["llm.model.layers.9"] = device_id2
            device_map["llm.model.layers.10"] = device_id2
            device_map["llm.model.layers.11"] = device_id2
            device_map["llm.model.layers.12"] = device_id2
            device_map["llm.model.layers.13"] = device_id2
            device_map["llm.model.layers.14"] = device_id2
            device_map["llm.model.layers.15"] = device_id2
            device_map["llm.model.layers.16"] = device_id2
            print(device_map)
            self.model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
            self.model.eval()
        else:
            self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
                                                   attn_implementation='sdpa', torch_dtype=torch.bfloat16)
            self.model.eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    def chat(self, input):
        image = None
        if "image" in input and len(input["image"]) > 10:  # legacy API: a single base64-encoded image
            try:
                image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
            except Exception as e:
                return "Image decode error"
        msgs = json.loads(input["question"])
        for msg in msgs:
            contents = msg.pop('content')  # supports str or List[Dict]
            if isinstance(contents, str):
                contents = [contents]
            new_cnts = []
            for c in contents:
                if isinstance(c, dict):
                    if c['type'] == 'text':
                        c = c['pairs']
                    elif c['type'] == 'image':
                        c = Image.open(io.BytesIO(base64.b64decode(c["pairs"]))).convert('RGB')
                    else:
                        raise ValueError("content type only supports text and image.")
                new_cnts.append(c)
            msg['content'] = new_cnts
        print(f'msgs: {str(msgs)}')
        answer = self.model.chat(
            image=image,
            msgs=msgs,
            tokenizer=self.tokenizer,
        )
        return answer


class MiniCPMVChat:
    def __init__(self, model_path) -> None:
    def __init__(self, model_path, multi_gpus=False) -> None:
        if '12B' in model_path:
            self.model = OmniLMM12B(model_path)
        elif 'MiniCPM-Llama3-V' in model_path:
            self.model = MiniCPMV2_5(model_path)
        elif 'MiniCPM-V-2_6' in model_path:
            self.model = MiniCPMV2_6(model_path, multi_gpus)
        else:
            self.model = MiniCPMV(model_path)
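
For reference, a minimal usage sketch of the new entry point (not part of the commit). It assumes `chat.py` is importable from the repo root, that `MiniCPMVChat` still exposes the same `chat(input)` wrapper as before, and that `example.jpg` is a placeholder image path.

```python
import base64
import json

from chat import MiniCPMVChat  # assumes the repo's chat.py is on PYTHONPATH

chat_model = MiniCPMVChat('openbmb/MiniCPM-V-2_6', multi_gpus=False)

with open('example.jpg', 'rb') as f:  # placeholder image path
    image_b64 = base64.b64encode(f.read()).decode('utf-8')

# New-style message: content is a list of {"type": ..., "pairs": ...} items,
# matching what MiniCPMV2_6.chat() decodes above.
msgs = [{
    'role': 'user',
    'content': [
        {'type': 'image', 'pairs': image_b64},
        {'type': 'text', 'pairs': 'What is in this image?'},
    ],
}]

answer = chat_model.chat({'question': json.dumps(msgs)})
print(answer)
```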

30  docs/faqs.md  (new file)

@@ -0,0 +1,30 @@
### FAQs
<details>
<summary>Q: How to choose between sampling and beam search for inference?</summary>

The quality of results from beam search and sampling can differ from scenario to scenario. You can choose a decoding strategy based on the following aspects.

Consider sampling if:

1. You need faster inference.
2. You want streaming generation.
3. Your task calls for open-ended responses.

If your task requires deterministic answers, it is worth trying beam search to see whether it gives better results. Both can be requested through `model.chat`, as sketched below.
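
A hedged sketch of both settings; the parameter names follow the `min_new_tokens` example below and the defaults used in `web_demo_2.6.py`, so treat the exact values as assumptions to tune for your own task.

```python
# Sampling: faster, supports streaming, better for open-ended answers.
res = model.chat(
    image=None,
    msgs=msgs,
    tokenizer=tokenizer,
    sampling=True,
    top_p=0.8,
    top_k=100,
    temperature=0.7,
    repetition_penalty=1.05,
)

# Beam search: deterministic, often better for short factual answers.
res = model.chat(
    image=None,
    msgs=msgs,
    tokenizer=tokenizer,
    sampling=False,
    num_beams=3,
    repetition_penalty=1.2,
)
```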
</details>
<details>
<summary>Q: How to ensure the model generates sufficiently long outputs?</summary>

We have observed that multilingual generation with MiniCPM-V 2.6 sometimes ends prematurely. You can improve the results by passing a `min_new_tokens` parameter:
```python
res = model.chat(
image=None,
msgs=msgs,
tokenizer=tokenizer,
min_new_tokens=100
)
```
</details>

View File

@@ -105,7 +105,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
}
def conversation_to_ids(conversation, tokenizer, llm_type=None):
def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
"""
for single image multi-turn conversation
conversation: [{'role': 'user', 'content': 'Describe this image'},
@@ -115,6 +115,10 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
input_ids, context, raw_msg = conversation_to_ids_llama3(
conversation, tokenizer
)
elif llm_type == "qwen2":
input_ids, context, raw_msg = conversation_to_ids_qwen2(
conversation, tokenizer
)
else:
input_ids, context, raw_msg = conversation_to_ids_minicpm(
conversation, tokenizer
@@ -125,6 +129,7 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
# build target
target = torch.full_like(ids, -100, dtype=torch.int32)
for i in range(1, len(ids)):
if context[i] == 0:
target[i - 1] = ids[i]
@@ -133,14 +138,21 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
target[i - 1] = tokenizer.eot_id
else:
target[i - 1] = tokenizer.eos_id
# build image bound
image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
image_start_tokens += 1
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
if new_schema:
start_cond = (ids == tokenizer.im_start_id) | (ids == tokenizer.slice_start_id)
end_cond = (ids == tokenizer.im_end_id) | (ids == tokenizer.slice_end_id)
image_start_tokens = torch.where(start_cond)[0]
image_start_tokens += 1
image_end_tokens = torch.where(end_cond)[0]
else:
image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
image_start_tokens += 1
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
if len(image_start_tokens) != len(image_end_tokens):
print("image start token != image end tokens")
if len(image_start_tokens) > 0:
image_bound = torch.hstack(
[image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
@@ -230,6 +242,46 @@ def conversation_to_ids_llama3(conversation, tokenizer):
return input_ids, context, raw_msg
def conversation_to_ids_qwen2(conversation, tokenizer):
    raw_msg = ""
    chat = []
    context = []
    for idx, msg in enumerate(conversation):
        role = msg["role"]
        message = msg["content"]
        assert role in ["user", "assistant"]
        if role == "user":
            prefix = "user"
        else:
            prefix = "assistant"
        chat.append({"role": prefix, "content": message})
        raw_msg += prefix + message
    assert set([i['role'] for i in chat]) & set(['assistant'])

    ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
    input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False)
    input_ids = np.array(input_ids)

    start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0]
    assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0]
    end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0]

    context = np.ones_like(input_ids, dtype=np.int8)
    for assistant_idx in assistant_idxs:
        if assistant_idx - 1 in set(start_idxs):
            st = assistant_idx + 1
            for end_idx in end_idxs:
                if end_idx > st:
                    context[st: end_idx + 1] = 0
                    break

    input_ids = np.hstack(input_ids)
    context = np.hstack(context)
    return input_ids, context, raw_msg
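
To make the masking convention concrete, here is a toy, self-contained sketch (not repo code; the token ids are fake) of how the `context` array produced above feeds the target construction shown earlier in `conversation_to_ids`: positions where `context == 0` (assistant tokens) become next-token labels, everything else stays `-100`.

```python
import numpy as np
import torch

input_ids = np.array([100, 101, 102, 103, 104, 105])  # fake token ids
context   = np.array([  1,   1,   1,   0,   0,   0])  # 0 = assistant span

ids = torch.from_numpy(input_ids)
target = torch.full_like(ids, -100, dtype=torch.int32)
for i in range(1, len(ids)):
    if context[i] == 0:
        target[i - 1] = ids[i]  # next-token prediction only on assistant tokens

print(target.tolist())  # [-100, -100, 103, 104, 105, -100]
```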
def preprocess(
image,
conversation,
@@ -256,8 +308,14 @@ def preprocess(
default_image_placeholder = (
tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
)
new_schema = False
use_image_id = False
if llm_type=='qwen2':
new_schema = True
use_image_id = True
if slice_config:
images = []
image_id_cnt = 0
source_image, patches, best_grid = slice_image(
image,
slice_config["max_slice_nums"],
@@ -270,9 +328,11 @@ def preprocess(
for i in range(len(patches)):
for j in range(len(patches[0])):
images.append(patches[i][j])
if use_image_id:
image_placeholder = f'{tokenizer.im_id_start}{idx}{tokenizer.im_id_end}' + image_placeholder
image_id_cnt += 1
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums)
tokenizer, best_grid, query_nums, new_schema = new_schema)
images = [transform(i) for i in images]
else:
images = [transform(image)]
@@ -286,7 +346,7 @@ def preprocess(
image_placeholder + "\n" + conversation[0]["content"]
)
input_dict = conversation_to_ids(conversation, tokenizer, llm_type)
input_dict = conversation_to_ids(conversation, tokenizer, llm_type, new_schema)
if batch_vision:
tgt_sizes = []
@@ -424,7 +484,7 @@ def split_to_patches(image, grid):
return patches
def get_grid_placeholder(tokenizer, grid, query_num):
def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):
image_placeholder = (
tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
)
@@ -437,7 +497,10 @@ def get_grid_placeholder(tokenizer, grid, query_num):
for j in range(cols):
lines.append(image_placeholder)
slices.append("".join(lines))
slice_placeholder = tokenizer.slice_start + \
if new_schema:
slice_placeholder = '\n'.join(slices)
else:
slice_placeholder = tokenizer.slice_start + \
"\n".join(slices) + tokenizer.slice_end
return slice_placeholder
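
A toy sketch (not repo code; the token strings are illustrative stand-ins for the real tokenizer attributes) of the layout difference controlled by `new_schema`: the qwen2 path drops the outer slice_start/slice_end wrapper and keeps only newline-joined rows of slice placeholders.

```python
from types import SimpleNamespace

# Illustrative stand-in tokens; the real values come from the MiniCPM tokenizer.
tok = SimpleNamespace(im_start='<image>', im_end='</image>', unk_token='<unk>',
                      slice_start='<slice>', slice_end='</slice>')
rows, cols, query_num = 2, 2, 3
cell = tok.im_start + tok.unk_token * query_num + tok.im_end
row = ''.join([cell] * cols)

old_placeholder = tok.slice_start + '\n'.join([row] * rows) + tok.slice_end  # new_schema=False
new_placeholder = '\n'.join([row] * rows)                                    # new_schema=True
print(old_placeholder)
print(new_placeholder)
```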
@@ -455,4 +518,4 @@ def reshape_by_patch(image_tensor, patch_size):
patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)
patches = patches.permute(0, 1, 3, 2).reshape(
image_tensor.size(0), patch_size, -1)
return patches
return patches

View File

@@ -6,6 +6,8 @@ from dataclasses import dataclass, field
from functools import partial
from typing import Dict, List, Optional, Union, Literal, Tuple
from types import MethodType
from torchvision import transforms
import torch
import transformers
from accelerate.utils import DistributedType
@@ -130,6 +132,18 @@ def make_supervised_data_module(
)
def build_transform():
    IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)  # timm.data.IMAGENET_INCEPTION_MEAN
    IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)   # timm.data.IMAGENET_INCEPTION_STD
    return transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
            ),
        ]
    )
def get_parameter_number(model):
trainable_params, all_param = 0, 0
for param in model.parameters():
@@ -248,10 +262,11 @@ def train():
else:
batch_vision = False
transform_func = build_transform()
data_module = make_supervised_data_module(
tokenizer=tokenizer,
data_args=data_args,
transform=model.transform,
transform=transform_func,
data_collator=data_collator,
slice_config=slice_config,
llm_type=llm_type,

View File

@@ -6,12 +6,15 @@ NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=6001
MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
MODEL="openbmb/MiniCPM-V-2_6"
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="path/to/trainging_data"
EVAL_DATA="path/to/test_data"
LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
@@ -28,10 +31,10 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--remove_unused_columns false \
--label_names "labels" \
--prediction_loss_only false \
--bf16 false \
--bf16_full_eval false \
--fp16 true \
--fp16_full_eval true \
--bf16 true \
--bf16_full_eval true \
--fp16 false \
--fp16_full_eval false \
--do_train \
--do_eval \
--tune_vision true \
@@ -40,8 +43,8 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--max_slice_nums 9 \
--max_steps 10000 \
--eval_steps 1000 \
--output_dir output/output_minicpmv2 \
--logging_dir output/output_minicpmv2 \
--output_dir output/output_minicpmv26 \
--logging_dir output/output_minicpmv26 \
--logging_strategy "steps" \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \

View File

@@ -6,13 +6,14 @@ NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=6001
MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
MODEL="openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="path/to/trainging_data"
EVAL_DATA="path/to/test_data"
LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
LLM_TYPE="qwen2"
# set LLM_TYPE=minicpm for openbmb/MiniCPM-V-2
# set LLM_TYPE=llama3 for openbmb/MiniCPM-Llama3-V-2_5
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
@@ -42,12 +43,12 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--max_slice_nums 9 \
--max_steps 10000 \
--eval_steps 1000 \
--output_dir output/output_minicpmv2_lora \
--logging_dir output/output_minicpmv2_lora \
--output_dir output/output_lora \
--logging_dir output/output_lora \
--logging_strategy "steps" \
--per_device_train_batch_size 2 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "steps" \
--save_strategy "steps" \
--save_steps 1000 \

View File

@@ -1,6 +1,76 @@
# MiniCPM-V Finetuning
We offer official scripts for easy finetuning of the pretrained **MiniCPM-V-2_6**, **MiniCPM-Llama3-V 2.5** and **MiniCPM-V 2.0** on downstream tasks. Our finetuning scripts use the transformers Trainer and DeepSpeed by default.
### Data preparation
To prepare your finetuning data, format each sample as a dictionary with an id, an image path list containing one image, and a list of conversations, then save the samples in a JSON file. A small validation sketch follows the example below.
For vision-language samples with an image, you must include the **\<image\>** tag to mark where the image embeddings are inserted. If you omit \<image\>, the image is placed at the front of the conversation.
<details>
<summary>
<b>vision-language example (vl_finetune_data.json) with 1 sample.</b>
</summary>
```json
[
    {
        "id": "0",
        "image": "path/to/image_0.jpg",
        "conversations": [
            {
                "role": "user",
                "content": "<image>\nHow many desserts are on the white plate?"
            },
            {
                "role": "assistant",
                "content": "There are three desserts on the white plate."
            },
            {
                "role": "user",
                "content": "What type of desserts are they?"
            },
            {
                "role": "assistant",
                "content": "The desserts are cakes with bananas and pecans on top. They share similarities with donuts, but the presence of bananas and pecans differentiates them."
            },
            {
                "role": "user",
                "content": "What is the setting of the image?"
            },
            {
                "role": "assistant",
                "content": "The image is set on a table top with a plate containing the three desserts."
            }
        ]
    }
]
```
</details>
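
A small sanity-check sketch (not part of the official scripts; `vl_finetune_data.json` is a placeholder path) for verifying that a data file matches the format above before launching training:

```python
import json

with open('vl_finetune_data.json') as f:  # placeholder path
    samples = json.load(f)

for sample in samples:
    assert {'id', 'image', 'conversations'} <= sample.keys()
    convs = sample['conversations']
    assert convs[0]['role'] == 'user'
    # <image> marks where the image embeddings are inserted; if it is missing,
    # the image is prepended to the first user turn.
    if '<image>' not in convs[0]['content']:
        print(f"sample {sample['id']}: no <image> tag, image will be placed at the front")
```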
### Full-parameter finetuning
Full-parameter finetuning updates all parameters of the LLM during training. Please specify the correct MODEL and DATA paths in the shell scripts.
```shell
MODEL="openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2
DATA="path/to/trainging_data" # json file
EVAL_DATA="path/to/test_data" # json file
```
To launch your training, run the following script:
```
sh finetune_ds.sh
```
#### Customizing Hyperparameters
To tailor the training process to your specific requirements, you can adjust various hyperparameters. For comprehensive documentation of the available hyperparameters and their functionality, refer to the [official Transformers documentation](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments). Experimenting with these parameters is essential for achieving optimal performance on your task and dataset.
# MiniCPM-V Finetuning
We offer the official scripts for easy finetuning of the pretrained **MiniCPM-Llama3-V 2.5** and **MiniCPM-V 2.0** on downstream tasks. Our finetune scripts use transformers Trainer and DeepSpeed by default.
### Data preparation
@@ -55,10 +125,10 @@ For the vision-language example with image, you are required to provide **\<imag
Full-parameter finetuning updates all parameters of the LLM during training. Please specify the correct MODEL path, DATA path and LLM_TYPE in the shell scripts.
```shell
MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
MODEL="openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2
DATA="path/to/trainging_data" # json file
EVAL_DATA="path/to/test_data" # json file
LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
```
To launch your training, run the following script:
@@ -82,7 +152,7 @@ After training, you could load the model with the path to the adapter. We advise
```
from peft import PeftModel
from transformers import AutoModel
model_type="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
model_type= "openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-Llama3-V-2_5 , openbmb/MiniCPM-V-2
path_to_adapter="path_to_your_fine_tuned_checkpoint"
model = AutoModel.from_pretrained(
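
A minimal sketch of loading the adapter on top of the base model, assuming the standard PEFT `from_pretrained` API; the paths below are placeholders:

```python
import torch
from peft import PeftModel
from transformers import AutoModel

model_type = "openbmb/MiniCPM-V-2_6"                    # or MiniCPM-Llama3-V-2_5 / MiniCPM-V-2
path_to_adapter = "path_to_your_fine_tuned_checkpoint"  # placeholder

base_model = AutoModel.from_pretrained(
    model_type,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
lora_model = PeftModel.from_pretrained(base_model, path_to_adapter).eval().cuda()
```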

View File

@@ -1,3 +1,4 @@
import torch
import torch.nn as nn
import deepspeed

View File

@@ -29,5 +29,7 @@ uvicorn==0.24.0.post1
sentencepiece==0.1.99
accelerate==0.30.1
socksio==1.0.0
gradio
gradio==4.22.0
gradio_client
http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/modelscope_studio-0.4.0.9-py3-none-any.whl
decord

557  web_demo_2.6.py  (new file)

@@ -0,0 +1,557 @@
#!/usr/bin/env python
# encoding: utf-8
import torch
import argparse
from transformers import AutoModel, AutoTokenizer
import gradio as gr
from PIL import Image
from decord import VideoReader, cpu
import io
import os
import copy
import requests
import base64
import json
import traceback
import re
import modelscope_studio as mgr
# README, How to run demo on different devices
# For Nvidia GPUs.
# python web_demo_2.6.py --device cuda
# For Mac with MPS (Apple silicon or AMD GPUs).
# PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo_2.6.py --device mps
# Argparser
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
parser.add_argument('--multi-gpus', action='store_true', default=False, help='use multi-gpus')
args = parser.parse_args()
device = args.device
assert device in ['cuda', 'mps']
# Load model
model_path = 'openbmb/MiniCPM-V-2_6'
if 'int4' in model_path:
if device == 'mps':
print('Error: running int4 model with bitsandbytes on Mac is not supported right now.')
exit()
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
else:
if args.multi_gpus:
from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map
with init_empty_weights():
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation='sdpa', torch_dtype=torch.bfloat16)
device_map = infer_auto_device_map(model, max_memory={0: "10GB", 1: "10GB"},
no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'])
device_id = device_map["llm.model.embed_tokens"]
device_map["llm.lm_head"] = device_id # firtt and last layer should be in same device
device_map["vpm"] = device_id
device_map["resampler"] = device_id
device_id2 = device_map["llm.model.layers.26"]
device_map["llm.model.layers.8"] = device_id2
device_map["llm.model.layers.9"] = device_id2
device_map["llm.model.layers.10"] = device_id2
device_map["llm.model.layers.11"] = device_id2
device_map["llm.model.layers.12"] = device_id2
device_map["llm.model.layers.13"] = device_id2
device_map["llm.model.layers.14"] = device_id2
device_map["llm.model.layers.15"] = device_id2
device_map["llm.model.layers.16"] = device_id2
#print(device_map)
model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
else:
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
model = model.to(device=device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.eval()
ERROR_MSG = "Error, please retry"
model_name = 'MiniCPM-V 2.6'
MAX_NUM_FRAMES = 64
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
def get_file_extension(filename):
return os.path.splitext(filename)[1].lower()
def is_image(filename):
return get_file_extension(filename) in IMAGE_EXTENSIONS
def is_video(filename):
return get_file_extension(filename) in VIDEO_EXTENSIONS
form_radio = {
'choices': ['Beam Search', 'Sampling'],
#'value': 'Beam Search',
'value': 'Sampling',
'interactive': True,
'label': 'Decode Type'
}
def create_component(params, comp='Slider'):
if comp == 'Slider':
return gr.Slider(
minimum=params['minimum'],
maximum=params['maximum'],
value=params['value'],
step=params['step'],
interactive=params['interactive'],
label=params['label']
)
elif comp == 'Radio':
return gr.Radio(
choices=params['choices'],
value=params['value'],
interactive=params['interactive'],
label=params['label']
)
elif comp == 'Button':
return gr.Button(
value=params['value'],
interactive=True
)
def create_multimodal_input(upload_image_disabled=False, upload_video_disabled=False):
return mgr.MultimodalInput(upload_image_button_props={'label': 'Upload Image', 'disabled': upload_image_disabled, 'file_count': 'multiple'},
upload_video_button_props={'label': 'Upload Video', 'disabled': upload_video_disabled, 'file_count': 'single'},
submit_button_props={'label': 'Submit'})
def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
try:
print('msgs:', msgs)
answer = model.chat(
image=None,
msgs=msgs,
tokenizer=tokenizer,
**params
)
res = re.sub(r'(<box>.*</box>)', '', answer)
res = res.replace('<ref>', '')
res = res.replace('</ref>', '')
res = res.replace('<box>', '')
answer = res.replace('</box>', '')
print('answer:', answer)
return 0, answer, None, None
except Exception as e:
print(e)
traceback.print_exc()
return -1, ERROR_MSG, None, None
def encode_image(image):
if not isinstance(image, Image.Image):
if hasattr(image, 'path'):
image = Image.open(image.path).convert("RGB")
else:
image = Image.open(image.file.path).convert("RGB")
# resize to max_size
max_size = 448*16
if max(image.size) > max_size:
w,h = image.size
if w > h:
new_w = max_size
new_h = int(h * max_size / w)
else:
new_h = max_size
new_w = int(w * max_size / h)
image = image.resize((new_w, new_h), resample=Image.BICUBIC)
return image
## save by BytesIO and convert to base64
#buffered = io.BytesIO()
#image.save(buffered, format="png")
#im_b64 = base64.b64encode(buffered.getvalue()).decode()
#return {"type": "image", "pairs": im_b64}
def encode_video(video):
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
if hasattr(video, 'path'):
vr = VideoReader(video.path, ctx=cpu(0))
else:
vr = VideoReader(video.file.path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_idx = [i for i in range(0, len(vr), sample_fps)]
if len(frame_idx)>MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
video = vr.get_batch(frame_idx).asnumpy()
video = [Image.fromarray(v.astype('uint8')) for v in video]
video = [encode_image(v) for v in video]
print('video frames:', len(video))
return video
def check_mm_type(mm_file):
if hasattr(mm_file, 'path'):
path = mm_file.path
else:
path = mm_file.file.path
if is_image(path):
return "image"
if is_video(path):
return "video"
return None
def encode_mm_file(mm_file):
if check_mm_type(mm_file) == 'image':
return [encode_image(mm_file)]
if check_mm_type(mm_file) == 'video':
return encode_video(mm_file)
return None
def make_text(text):
#return {"type": "text", "pairs": text} # # For remote call
return text
def encode_message(_question):
files = _question.files
question = _question.text
pattern = r"\[mm_media\]\d+\[/mm_media\]"
matches = re.split(pattern, question)
message = []
if len(matches) != len(files) + 1:
gr.Warning("Number of Images not match the placeholder in text, please refresh the page to restart!")
assert len(matches) == len(files) + 1
text = matches[0].strip()
if text:
message.append(make_text(text))
for i in range(len(files)):
message += encode_mm_file(files[i])
text = matches[i + 1].strip()
if text:
message.append(make_text(text))
return message
def check_has_videos(_question):
images_cnt = 0
videos_cnt = 0
for file in _question.files:
if check_mm_type(file) == "image":
images_cnt += 1
else:
videos_cnt += 1
return images_cnt, videos_cnt
def count_video_frames(_context):
num_frames = 0
for message in _context:
for item in message["content"]:
#if item["type"] == "image": # For remote call
if isinstance(item, Image.Image):
num_frames += 1
return num_frames
def respond(_question, _chat_bot, _app_cfg, params_form):
_context = _app_cfg['ctx'].copy()
_context.append({'role': 'user', 'content': encode_message(_question)})
images_cnt = _app_cfg['images_cnt']
videos_cnt = _app_cfg['videos_cnt']
files_cnts = check_has_videos(_question)
if files_cnts[1] + videos_cnt > 1 or (files_cnts[1] + videos_cnt == 1 and files_cnts[0] + images_cnt > 0):
gr.Warning("Only supports single video file input right now!")
return _question, _chat_bot, _app_cfg
if params_form == 'Beam Search':
params = {
'sampling': False,
'num_beams': 3,
'repetition_penalty': 1.2,
"max_new_tokens": 2048
}
else:
params = {
'sampling': True,
'top_p': 0.8,
'top_k': 100,
'temperature': 0.7,
'repetition_penalty': 1.05,
"max_new_tokens": 2048
}
if files_cnts[1] + videos_cnt > 0:
params["max_inp_length"] = 4352 # 4096+256
params["use_image_id"] = False
params["max_slice_nums"] = 1 if count_video_frames(_context) > 16 else 2
code, _answer, _, sts = chat("", _context, None, params)
images_cnt += files_cnts[0]
videos_cnt += files_cnts[1]
_context.append({"role": "assistant", "content": [make_text(_answer)]})
_chat_bot.append((_question, _answer))
if code == 0:
_app_cfg['ctx']=_context
_app_cfg['sts']=sts
_app_cfg['images_cnt'] = images_cnt
_app_cfg['videos_cnt'] = videos_cnt
upload_image_disabled = videos_cnt > 0
upload_video_disabled = videos_cnt > 0 or images_cnt > 0
return create_multimodal_input(upload_image_disabled, upload_video_disabled), _chat_bot, _app_cfg
def fewshot_add_demonstration(_image, _user_message, _assistant_message, _chat_bot, _app_cfg):
ctx = _app_cfg["ctx"]
message_item = []
if _image is not None:
image = Image.open(_image).convert("RGB")
ctx.append({"role": "user", "content": [encode_image(image), make_text(_user_message)]})
message_item.append({"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]})
else:
if _user_message:
ctx.append({"role": "user", "content": [make_text(_user_message)]})
message_item.append({"text": _user_message, "files": []})
else:
message_item.append(None)
if _assistant_message:
ctx.append({"role": "assistant", "content": [make_text(_assistant_message)]})
message_item.append({"text": _assistant_message, "files": []})
else:
message_item.append(None)
_chat_bot.append(message_item)
return None, "", "", _chat_bot, _app_cfg
def fewshot_respond(_image, _user_message, _chat_bot, _app_cfg, params_form):
user_message_contents = []
_context = _app_cfg["ctx"].copy()
if _image:
image = Image.open(_image).convert("RGB")
user_message_contents += [encode_image(image)]
if _user_message:
user_message_contents += [make_text(_user_message)]
if user_message_contents:
_context.append({"role": "user", "content": user_message_contents})
if params_form == 'Beam Search':
params = {
'sampling': False,
'num_beams': 3,
'repetition_penalty': 1.2,
"max_new_tokens": 2048
}
else:
params = {
'sampling': True,
'top_p': 0.8,
'top_k': 100,
'temperature': 0.7,
'repetition_penalty': 1.05,
"max_new_tokens": 2048
}
code, _answer, _, sts = chat("", _context, None, params)
_context.append({"role": "assistant", "content": [make_text(_answer)]})
if _image:
_chat_bot.append([
{"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]},
{"text": _answer, "files": []}
])
else:
_chat_bot.append([
{"text": _user_message, "files": [_image]},
{"text": _answer, "files": []}
])
if code == 0:
_app_cfg['ctx']=_context
_app_cfg['sts']=sts
return None, '', '', _chat_bot, _app_cfg
def regenerate_button_clicked(_question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg, params_form):
if len(_chat_bot) <= 1 or not _chat_bot[-1][1]:
gr.Warning('No question for regeneration.')
return '', _image, _user_message, _assistant_message, _chat_bot, _app_cfg
if _app_cfg["chat_type"] == "Chat":
images_cnt = _app_cfg['images_cnt']
videos_cnt = _app_cfg['videos_cnt']
_question = _chat_bot[-1][0]
_chat_bot = _chat_bot[:-1]
_app_cfg['ctx'] = _app_cfg['ctx'][:-2]
files_cnts = check_has_videos(_question)
images_cnt -= files_cnts[0]
videos_cnt -= files_cnts[1]
_app_cfg['images_cnt'] = images_cnt
_app_cfg['videos_cnt'] = videos_cnt
upload_image_disabled = videos_cnt > 0
upload_video_disabled = videos_cnt > 0 or images_cnt > 0
_question, _chat_bot, _app_cfg = respond(_question, _chat_bot, _app_cfg, params_form)
return _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
else:
last_message = _chat_bot[-1][0]
last_image = None
last_user_message = ''
if last_message.text:
last_user_message = last_message.text
if last_message.files:
last_image = last_message.files[0].file.path
_chat_bot = _chat_bot[:-1]
_app_cfg['ctx'] = _app_cfg['ctx'][:-2]
_image, _user_message, _assistant_message, _chat_bot, _app_cfg = fewshot_respond(last_image, last_user_message, _chat_bot, _app_cfg, params_form)
return _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
def flushed():
return gr.update(interactive=True)
def clear(txt_message, chat_bot, app_session):
txt_message.files.clear()
txt_message.text = ''
chat_bot = copy.deepcopy(init_conversation)
app_session['sts'] = None
app_session['ctx'] = []
app_session['images_cnt'] = 0
app_session['videos_cnt'] = 0
return create_multimodal_input(), chat_bot, app_session, None, '', ''
def select_chat_type(_tab, _app_cfg):
_app_cfg["chat_type"] = _tab
return _app_cfg
init_conversation = [
[
None,
{
# The first message of bot closes the typewriter.
"text": "You can talk to me now",
"flushing": False
}
],
]
css = """
video { height: auto !important; }
.example label { font-size: 16px;}
"""
introduction = """
## Features:
1. Chat with single image
2. Chat with multiple images
3. Chat with video
4. In-context few-shot learning
Click `How to use` tab to see examples.
"""
with gr.Blocks(css=css) as demo:
with gr.Tab(model_name):
with gr.Row():
with gr.Column(scale=1, min_width=300):
gr.Markdown(value=introduction)
params_form = create_component(form_radio, comp='Radio')
regenerate = create_component({'value': 'Regenerate'}, comp='Button')
clear_button = create_component({'value': 'Clear History'}, comp='Button')
with gr.Column(scale=3, min_width=500):
app_session = gr.State({'sts':None,'ctx':[], 'images_cnt': 0, 'videos_cnt': 0, 'chat_type': 'Chat'})
chat_bot = mgr.Chatbot(label=f"Chat with {model_name}", value=copy.deepcopy(init_conversation), height=600, flushing=False, bubble_full_width=False)
with gr.Tab("Chat") as chat_tab:
txt_message = create_multimodal_input()
chat_tab_label = gr.Textbox(value="Chat", interactive=False, visible=False)
txt_message.submit(
respond,
[txt_message, chat_bot, app_session, params_form],
[txt_message, chat_bot, app_session]
)
with gr.Tab("Few Shot") as fewshot_tab:
fewshot_tab_label = gr.Textbox(value="Few Shot", interactive=False, visible=False)
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(type="filepath", sources=["upload"])
with gr.Column(scale=3):
user_message = gr.Textbox(label="User")
assistant_message = gr.Textbox(label="Assistant")
with gr.Row():
add_demonstration_button = gr.Button("Add Example")
generate_button = gr.Button(value="Generate", variant="primary")
add_demonstration_button.click(
fewshot_add_demonstration,
[image_input, user_message, assistant_message, chat_bot, app_session],
[image_input, user_message, assistant_message, chat_bot, app_session]
)
generate_button.click(
fewshot_respond,
[image_input, user_message, chat_bot, app_session, params_form],
[image_input, user_message, assistant_message, chat_bot, app_session]
)
chat_tab.select(
select_chat_type,
[chat_tab_label, app_session],
[app_session]
)
chat_tab.select( # do clear
clear,
[txt_message, chat_bot, app_session],
[txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
)
fewshot_tab.select(
select_chat_type,
[fewshot_tab_label, app_session],
[app_session]
)
fewshot_tab.select( # do clear
clear,
[txt_message, chat_bot, app_session],
[txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
)
chat_bot.flushed(
flushed,
outputs=[txt_message]
)
regenerate.click(
regenerate_button_clicked,
[txt_message, image_input, user_message, assistant_message, chat_bot, app_session, params_form],
[txt_message, image_input, user_message, assistant_message, chat_bot, app_session]
)
clear_button.click(
clear,
[txt_message, chat_bot, app_session],
[txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
)
with gr.Tab("How to use"):
with gr.Column():
with gr.Row():
image_example = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/m_bear2.gif", label='1. Chat with single or multiple images', interactive=False, width=400, elem_classes="example")
example2 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/video2.gif", label='2. Chat with video', interactive=False, width=400, elem_classes="example")
example3 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/fshot.gif", label='3. Few shot', interactive=False, width=400, elem_classes="example")
# launch
demo.launch(share=False, debug=True, show_api=False, server_port=8885, server_name="0.0.0.0")
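
Finally, a minimal offline sketch (not part of the demo) of the video path that `respond()` takes above: frames are sampled at roughly one per second, capped at MAX_NUM_FRAMES, passed as a list of images in a single user turn, and the demo's video-specific generation settings are applied. `video.mp4` is a placeholder path, and the extra keyword arguments are assumed to be accepted by the model's `chat` API as they are in the demo.

```python
import torch
from PIL import Image
from decord import VideoReader, cpu
from transformers import AutoModel, AutoTokenizer

model_path = 'openbmb/MiniCPM-V-2_6'
model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
                                  attn_implementation='sdpa', torch_dtype=torch.bfloat16)
model = model.eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

vr = VideoReader('video.mp4', ctx=cpu(0))          # placeholder video path
step = max(1, round(vr.get_avg_fps()))             # roughly one frame per second
frame_idx = list(range(0, len(vr), step))[:64]     # cap at MAX_NUM_FRAMES
frames = [Image.fromarray(f.astype('uint8')) for f in vr.get_batch(frame_idx).asnumpy()]

msgs = [{'role': 'user', 'content': frames + ['Describe this video.']}]
answer = model.chat(
    image=None,
    msgs=msgs,
    tokenizer=tokenizer,
    sampling=True,
    max_new_tokens=2048,
    max_inp_length=4352,                           # 4096 + 256, as in respond()
    use_image_id=False,
    max_slice_nums=1 if len(frames) > 16 else 2,
)
print(answer)
```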