Update to MiniCPM-V 2.6
1002
README_en.md
997
README_zh.md
BIN
assets/gif_cases/ai.gif
Normal file
|
After Width: | Height: | Size: 5.5 MiB |
BIN
assets/gif_cases/beer.gif
Normal file
|
After Width: | Height: | Size: 4.4 MiB |
BIN
assets/gif_cases/mb.gif
Normal file
|
After Width: | Height: | Size: 13 MiB |
BIN
assets/gif_cases/rabbit.gif
Normal file
|
After Width: | Height: | Size: 7.7 MiB |
|
Before Width: | Height: | Size: 14 MiB After Width: | Height: | Size: 10 MiB |
BIN
assets/gif_cases/wfh.gif
Normal file
|
After Width: | Height: | Size: 7.8 MiB |
BIN
assets/gif_cases/zoo.gif
Normal file
|
After Width: | Height: | Size: 3.8 MiB |
BIN
assets/minicpmv2_6/ICL-Mem.png
Normal file
|
After Width: | Height: | Size: 1.9 MiB |
BIN
assets/minicpmv2_6/ICL-elec.png
Normal file
|
After Width: | Height: | Size: 4.4 MiB |
BIN
assets/minicpmv2_6/multi_img-bike.png
Normal file
|
After Width: | Height: | Size: 8.0 MiB |
BIN
assets/minicpmv2_6/multi_img-code.png
Normal file
|
After Width: | Height: | Size: 4.5 MiB |
BIN
assets/minicpmv2_6/multi_img-menu.png
Normal file
|
After Width: | Height: | Size: 3.7 MiB |
BIN
assets/minicpmv2_6/multiling-medal.png
Normal file
|
After Width: | Height: | Size: 2.1 MiB |
BIN
assets/minicpmv2_6/multiling-olympic.png
Normal file
|
After Width: | Height: | Size: 2.6 MiB |
BIN
assets/radar_final.png
Normal file
|
After Width: | Height: | Size: 1.1 MiB |
76
chat.py
@@ -183,13 +183,87 @@ class MiniCPMV2_5:
|
||||
)
|
||||
return answer
|
||||
|
||||
class MiniCPMV2_6:
|
||||
def __init__(self, model_path, multi_gpus=False) -> None:
|
||||
|
||||
print('torch_version:', torch.__version__)
|
||||
if multi_gpus: # inference on multi-gpus
|
||||
from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map
|
||||
with init_empty_weights():
|
||||
model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
|
||||
attn_implementation='sdpa', torch_dtype=torch.bfloat16)
|
||||
|
||||
device_map = infer_auto_device_map(model, max_memory={0: "10GB", 1: "10GB"},
|
||||
no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'])
|
||||
device_id = device_map["llm.model.embed_tokens"]
|
||||
device_map["llm.lm_head"] = device_id # first and last layer of llm should be in the same device
|
||||
device_map["vpm"] = device_id
|
||||
device_map["resampler"] = device_id
|
||||
device_id2 = device_map["llm.model.layers.26"]
|
||||
device_map["llm.model.layers.8"] = device_id2
|
||||
device_map["llm.model.layers.9"] = device_id2
|
||||
device_map["llm.model.layers.10"] = device_id2
|
||||
device_map["llm.model.layers.11"] = device_id2
|
||||
device_map["llm.model.layers.12"] = device_id2
|
||||
device_map["llm.model.layers.13"] = device_id2
|
||||
device_map["llm.model.layers.14"] = device_id2
|
||||
device_map["llm.model.layers.15"] = device_id2
|
||||
device_map["llm.model.layers.16"] = device_id2
|
||||
print(device_map)
|
||||
|
||||
self.model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
|
||||
self.model.eval()
|
||||
else:
|
||||
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
|
||||
attn_implementation='sdpa', torch_dtype=torch.bfloat16)
|
||||
self.model.eval().cuda()
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
def chat(self, input):
|
||||
image = None
|
||||
if "image" in input and len(input["image"]) > 10: # legacy API
|
||||
try:
|
||||
image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
|
||||
except Exception as e:
|
||||
return "Image decode error"
|
||||
|
||||
msgs = json.loads(input["question"])
|
||||
|
||||
for msg in msgs:
|
||||
contents = msg.pop('content') # support str or List[Dict]
|
||||
if isinstance(contents, str):
|
||||
contents = [contents]
|
||||
|
||||
new_cnts = []
|
||||
for c in contents:
|
||||
if isinstance(c, dict):
|
||||
if c['type'] == 'text':
|
||||
c = c['pairs']
|
||||
elif c['type'] == 'image':
|
||||
c = Image.open(io.BytesIO(base64.b64decode(c["pairs"]))).convert('RGB')
|
||||
else:
|
||||
raise ValueError("content type only support text and image.")
|
||||
new_cnts.append(c)
|
||||
msg['content'] = new_cnts
|
||||
print(f'msgs: {str(msgs)}')
|
||||
|
||||
answer = self.model.chat(
|
||||
image=image,
|
||||
msgs=msgs,
|
||||
tokenizer=self.tokenizer,
|
||||
)
|
||||
return answer
|
||||
|
||||
|
||||
class MiniCPMVChat:
|
||||
def __init__(self, model_path) -> None:
|
||||
def __init__(self, model_path, multi_gpus=False) -> None:
|
||||
if '12B' in model_path:
|
||||
self.model = OmniLMM12B(model_path)
|
||||
elif 'MiniCPM-Llama3-V' in model_path:
|
||||
self.model = MiniCPMV2_5(model_path)
|
||||
elif 'MiniCPM-V-2_6' in model_path:
|
||||
self.model = MiniCPMV2_6(model_path, multi_gpus)
|
||||
else:
|
||||
self.model = MiniCPMV(model_path)
|
||||
|
||||
|
||||
30
docs/faqs.md
Normal file
@@ -0,0 +1,30 @@
|
||||
### FAQs
|
||||
|
||||
<details>
|
||||
<summary>Q: How to choose between sampling or beam search for inference </summary>
|
||||
|
||||
In various scenarios, the quality of results obtained from beam search and sampling decoding strategies can vary. You can determine your decoding strategy based on the following aspects:
|
||||
|
||||
If you have the following needs, consider using sampling decoding:
|
||||
|
||||
1. You require faster inference speed.
|
||||
2. You wish for a streaming generation approach.
|
||||
3. Your task necessitates some open-ended responses.
|
||||
|
||||
If your task is about providing deterministic answers, you might want to experiment with beam search to see if it can achieve better outcomes.
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Q: How to ensure that the model generates results of sufficient length</summary>
|
||||
|
||||
We've observed that during multi-language inference on MiniCPM-V 2.6, the generation sometimes ends prematurely. You can improve the results by passing a `min_new_tokens` parameter.
|
||||
```python
|
||||
res = model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=tokenizer,
|
||||
min_new_tokens=100
|
||||
)
|
||||
```
|
||||
</details>
|
||||
@@ -105,7 +105,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
|
||||
}
|
||||
|
||||
|
||||
def conversation_to_ids(conversation, tokenizer, llm_type=None):
|
||||
def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
|
||||
"""
|
||||
for single image multi-turn conversation
|
||||
conversation: [{'role': 'user', 'content': 'Describe this image'},
|
||||
@@ -115,6 +115,10 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
|
||||
input_ids, context, raw_msg = conversation_to_ids_llama3(
|
||||
conversation, tokenizer
|
||||
)
|
||||
elif llm_type == "qwen2":
|
||||
input_ids, context, raw_msg = conversation_to_ids_qwen2(
|
||||
conversation, tokenizer
|
||||
)
|
||||
else:
|
||||
input_ids, context, raw_msg = conversation_to_ids_minicpm(
|
||||
conversation, tokenizer
|
||||
@@ -125,6 +129,7 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
|
||||
|
||||
# build target
|
||||
target = torch.full_like(ids, -100, dtype=torch.int32)
|
||||
|
||||
for i in range(1, len(ids)):
|
||||
if context[i] == 0:
|
||||
target[i - 1] = ids[i]
|
||||
@@ -135,6 +140,13 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
|
||||
target[i - 1] = tokenizer.eos_id
|
||||
|
||||
# build image bound
|
||||
if new_schema:
|
||||
start_cond = (ids == tokenizer.im_start_id) | (ids == tokenizer.slice_start_id)
|
||||
end_cond = (ids == tokenizer.im_end_id) | (ids == tokenizer.slice_end_id)
|
||||
image_start_tokens = torch.where(start_cond)[0]
|
||||
image_start_tokens += 1
|
||||
image_end_tokens = torch.where(end_cond)[0]
|
||||
else:
|
||||
image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
|
||||
image_start_tokens += 1
|
||||
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
|
||||
@@ -230,6 +242,46 @@ def conversation_to_ids_llama3(conversation, tokenizer):
|
||||
return input_ids, context, raw_msg
|
||||
|
||||
|
||||
def conversation_to_ids_qwen2(conversation, tokenizer):
|
||||
raw_msg = ""
|
||||
chat = []
|
||||
context = []
|
||||
for idx, msg in enumerate(conversation):
|
||||
role = msg["role"]
|
||||
message = msg["content"]
|
||||
assert role in ["user", "assistant"]
|
||||
if role == "user":
|
||||
prefix = "user"
|
||||
else:
|
||||
prefix = "assistant"
|
||||
chat.append({"role":prefix, "content":message})
|
||||
raw_msg += prefix + message
|
||||
assert set([i['role'] for i in chat]) & set(['assistant'])
|
||||
|
||||
ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
|
||||
input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False)
|
||||
input_ids = np.array(input_ids)
|
||||
|
||||
start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0]
|
||||
assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0]
|
||||
end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0]
|
||||
|
||||
context = np.ones_like(input_ids, dtype=np.int8)
|
||||
|
||||
for assistant_idx in assistant_idxs:
|
||||
if assistant_idx-1 in set(start_idxs):
|
||||
st = assistant_idx + 1
|
||||
for end_idx in end_idxs:
|
||||
if end_idx > st:
|
||||
context[st: end_idx + 1] = 0
|
||||
break
|
||||
|
||||
input_ids = np.hstack(input_ids)
|
||||
context = np.hstack(context)
|
||||
return input_ids, context, raw_msg
|
||||
|
||||
|
||||
|
||||
def preprocess(
|
||||
image,
|
||||
conversation,
|
||||
@@ -256,8 +308,14 @@ def preprocess(
|
||||
default_image_placeholder = (
|
||||
tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
|
||||
)
|
||||
new_schema = False
|
||||
use_image_id = False
|
||||
if llm_type=='qwen2':
|
||||
new_schema = True
|
||||
use_image_id = True
|
||||
if slice_config:
|
||||
images = []
|
||||
image_id_cnt = 0
|
||||
source_image, patches, best_grid = slice_image(
|
||||
image,
|
||||
slice_config["max_slice_nums"],
|
||||
@@ -270,9 +328,11 @@ def preprocess(
|
||||
for i in range(len(patches)):
|
||||
for j in range(len(patches[0])):
|
||||
images.append(patches[i][j])
|
||||
|
||||
if use_image_id:
|
||||
image_placeholder = f'{tokenizer.im_id_start}{idx}{tokenizer.im_id_end}' + image_placeholder
|
||||
image_id_cnt += 1
|
||||
image_placeholder += get_grid_placeholder(
|
||||
tokenizer, best_grid, query_nums)
|
||||
tokenizer, best_grid, query_nums, new_schema = new_schema)
|
||||
images = [transform(i) for i in images]
|
||||
else:
|
||||
images = [transform(image)]
|
||||
@@ -286,7 +346,7 @@ def preprocess(
|
||||
image_placeholder + "\n" + conversation[0]["content"]
|
||||
)
|
||||
|
||||
input_dict = conversation_to_ids(conversation, tokenizer, llm_type)
|
||||
input_dict = conversation_to_ids(conversation, tokenizer, llm_type, new_schema)
|
||||
|
||||
if batch_vision:
|
||||
tgt_sizes = []
|
||||
@@ -424,7 +484,7 @@ def split_to_patches(image, grid):
|
||||
return patches
|
||||
|
||||
|
||||
def get_grid_placeholder(tokenizer, grid, query_num):
|
||||
def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):
|
||||
image_placeholder = (
|
||||
tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
|
||||
)
|
||||
@@ -437,6 +497,9 @@ def get_grid_placeholder(tokenizer, grid, query_num):
|
||||
for j in range(cols):
|
||||
lines.append(image_placeholder)
|
||||
slices.append("".join(lines))
|
||||
if new_schema:
|
||||
slice_placeholder = '\n'.join(slices)
|
||||
else:
|
||||
slice_placeholder = tokenizer.slice_start + \
|
||||
"\n".join(slices) + tokenizer.slice_end
|
||||
return slice_placeholder
|
||||
|
||||
@@ -6,6 +6,8 @@ from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union, Literal, Tuple
|
||||
from types import MethodType
|
||||
from torchvision import transforms
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from accelerate.utils import DistributedType
|
||||
@@ -130,6 +132,18 @@ def make_supervised_data_module(
|
||||
)
|
||||
|
||||
|
||||
def build_transform():
|
||||
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
|
||||
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
|
||||
return transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def get_parameter_number(model):
|
||||
trainable_params, all_param = 0, 0
|
||||
for param in model.parameters():
|
||||
@@ -248,10 +262,11 @@ def train():
|
||||
else:
|
||||
batch_vision = False
|
||||
|
||||
transform_func = build_transform()
|
||||
data_module = make_supervised_data_module(
|
||||
tokenizer=tokenizer,
|
||||
data_args=data_args,
|
||||
transform=model.transform,
|
||||
transform=transform_func,
|
||||
data_collator=data_collator,
|
||||
slice_config=slice_config,
|
||||
llm_type=llm_type,
|
||||
|
||||
@@ -6,12 +6,15 @@ NODE_RANK=0
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=6001
|
||||
|
||||
MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
|
||||
MODEL="openbmb/MiniCPM-V-2_6"
|
||||
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5
|
||||
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
|
||||
# See the section for finetuning in README for more information.
|
||||
DATA="path/to/trainging_data"
|
||||
EVAL_DATA="path/to/test_data"
|
||||
LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
|
||||
LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
|
||||
|
||||
|
||||
|
||||
DISTRIBUTED_ARGS="
|
||||
--nproc_per_node $GPUS_PER_NODE \
|
||||
@@ -28,10 +31,10 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
|
||||
--remove_unused_columns false \
|
||||
--label_names "labels" \
|
||||
--prediction_loss_only false \
|
||||
--bf16 false \
|
||||
--bf16_full_eval false \
|
||||
--fp16 true \
|
||||
--fp16_full_eval true \
|
||||
--bf16 true \
|
||||
--bf16_full_eval true \
|
||||
--fp16 false \
|
||||
--fp16_full_eval false \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--tune_vision true \
|
||||
@@ -40,8 +43,8 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
|
||||
--max_slice_nums 9 \
|
||||
--max_steps 10000 \
|
||||
--eval_steps 1000 \
|
||||
--output_dir output/output_minicpmv2 \
|
||||
--logging_dir output/output_minicpmv2 \
|
||||
--output_dir output/output_minicpmv26 \
|
||||
--logging_dir output/output_minicpmv26 \
|
||||
--logging_strategy "steps" \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
|
||||
@@ -6,13 +6,14 @@ NODE_RANK=0
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=6001
|
||||
|
||||
MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
|
||||
MODEL="openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5
|
||||
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
|
||||
# See the section for finetuning in README for more information.
|
||||
DATA="path/to/trainging_data"
|
||||
EVAL_DATA="path/to/test_data"
|
||||
LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
|
||||
|
||||
LLM_TYPE="qwen2"
|
||||
# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
|
||||
#if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE=llama3
|
||||
DISTRIBUTED_ARGS="
|
||||
--nproc_per_node $GPUS_PER_NODE \
|
||||
--nnodes $NNODES \
|
||||
@@ -42,12 +43,12 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
|
||||
--max_slice_nums 9 \
|
||||
--max_steps 10000 \
|
||||
--eval_steps 1000 \
|
||||
--output_dir output/output_minicpmv2_lora \
|
||||
--logging_dir output/output_minicpmv2_lora \
|
||||
--output_dir output/output__lora \
|
||||
--logging_dir output/output_lora \
|
||||
--logging_strategy "steps" \
|
||||
--per_device_train_batch_size 2 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--evaluation_strategy "steps" \
|
||||
--save_strategy "steps" \
|
||||
--save_steps 1000 \
|
||||
|
||||
@@ -1,6 +1,76 @@
|
||||
# MiniCPM-V Finetuning
|
||||
|
||||
|
||||
We offer the official scripts for easy finetuning of the pretrained **MiniCPM-V-2_6**, **MiniCPM-Llama3-V 2.5** and **MiniCPM-V 2.0** on downstream tasks. Our finetune scripts use transformers Trainer and DeepSpeed by default.
|
||||
|
||||
### Data preparation
|
||||
|
||||
To prepare your finetuning data, you should formulate each sample as a dictionary consisting of an id, an image path list with an image, and a list of conversations. Then save data samples in JSON files.
|
||||
|
||||
For the vision-language example with image, you are required to provide **\<image\>** to define the position to insert the image embeddings. If you don't provide \<image\>, the image will be placed at the front of the conversation.
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
<b>vision-language example (vl_finetune_data.json) with 1 samples.</b>
|
||||
</summary>
|
||||
|
||||
```
|
||||
[
|
||||
{
|
||||
"id": "0",
|
||||
"image": 'path/to/image_0.jpg',
|
||||
"conversations": [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': '<image>\nHow many desserts are on the white plate?'
|
||||
},
|
||||
{
|
||||
'role': 'assistant',
|
||||
'content': 'There are three desserts on the white plate.'
|
||||
},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'What type of desserts are they?'
|
||||
},
|
||||
{
|
||||
'role': 'assistant',
|
||||
'content': 'The desserts are cakes with bananas and pecans on top. They share similarities with donuts, but the presence of bananas and pecans differentiates them.'
|
||||
},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'What is the setting of the image?'},
|
||||
{
|
||||
'role': 'assistant',
|
||||
'content': 'The image is set on a table top with a plate containing the three desserts.'
|
||||
},
|
||||
]
|
||||
},
|
||||
]
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Full-parameter finetuning
|
||||
|
||||
Full-parameter parameter finetuning requires updating all parameters of LLM in the whole training process. Please specify the correct MODEL path and DATA path in the shell scripts.
|
||||
|
||||
```shell
|
||||
MODEL="openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2
|
||||
DATA="path/to/trainging_data" # json file
|
||||
EVAL_DATA="path/to/test_data" # json file
|
||||
```
|
||||
|
||||
To launch your training, run the following script:
|
||||
|
||||
```
|
||||
sh finetune_ds.sh
|
||||
```
|
||||
|
||||
#### Customizing Hyperparameters
|
||||
To tailor the training process according to your specific requirements, you can adjust various hyperparameters. For comprehensive documentation on available hyperparameters and their functionalities, you can refer to the [official Transformers documentation](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments). Experimentation and fine-tuning of these parameters are essential for achieving optimal model performance tailored to your specific task and dataset.
|
||||
# MiniCPM-V Finetuning
|
||||
|
||||
|
||||
We offer the official scripts for easy finetuning of the pretrained **MiniCPM-Llama3-V 2.5** and **MiniCPM-V 2.0** on downstream tasks. Our finetune scripts use transformers Trainer and DeepSpeed by default.
|
||||
|
||||
### Data preparation
|
||||
@@ -55,10 +125,10 @@ For the vision-language example with image, you are required to provide **\<imag
|
||||
Full-parameter parameter finetuning requires updating all parameters of LLM in the whole training process. Please specify the correct MODEL path, DATA path and LLM_TYPE in the shell scripts.
|
||||
|
||||
```shell
|
||||
MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
|
||||
MODEL="openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2
|
||||
DATA="path/to/trainging_data" # json file
|
||||
EVAL_DATA="path/to/test_data" # json file
|
||||
LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
|
||||
LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
|
||||
```
|
||||
|
||||
To launch your training, run the following script:
|
||||
@@ -82,7 +152,7 @@ After training, you could load the model with the path to the adapter. We advise
|
||||
```
|
||||
from peft import PeftModel
|
||||
from transformers import AutoModel
|
||||
model_type="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
|
||||
model_type= "openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-Llama3-V-2_5 , openbmb/MiniCPM-V-2
|
||||
path_to_adapter="path_to_your_fine_tuned_checkpoint"
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import deepspeed
|
||||
|
||||
@@ -29,5 +29,7 @@ uvicorn==0.24.0.post1
|
||||
sentencepiece==0.1.99
|
||||
accelerate==0.30.1
|
||||
socksio==1.0.0
|
||||
gradio
|
||||
gradio==4.22.0
|
||||
gradio_client
|
||||
http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/modelscope_studio-0.4.0.9-py3-none-any.whl
|
||||
decord
|
||||
|
||||
557
web_demo_2.6.py
Normal file
@@ -0,0 +1,557 @@
|
||||
#!/usr/bin/env python
|
||||
# encoding: utf-8
|
||||
import torch
|
||||
import argparse
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from decord import VideoReader, cpu
|
||||
import io
|
||||
import os
|
||||
import copy
|
||||
import requests
|
||||
import base64
|
||||
import json
|
||||
import traceback
|
||||
import re
|
||||
import modelscope_studio as mgr
|
||||
|
||||
|
||||
# README, How to run demo on different devices
|
||||
|
||||
# For Nvidia GPUs.
|
||||
# python web_demo_2.6.py --device cuda
|
||||
|
||||
# For Mac with MPS (Apple silicon or AMD GPUs).
|
||||
# PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo_2.6.py --device mps
|
||||
|
||||
# Argparser
|
||||
parser = argparse.ArgumentParser(description='demo')
|
||||
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
|
||||
parser.add_argument('--multi-gpus', action='store_true', default=False, help='use multi-gpus')
|
||||
args = parser.parse_args()
|
||||
device = args.device
|
||||
assert device in ['cuda', 'mps']
|
||||
|
||||
# Load model
|
||||
model_path = 'openbmb/MiniCPM-V-2_6'
|
||||
if 'int4' in model_path:
|
||||
if device == 'mps':
|
||||
print('Error: running int4 model with bitsandbytes on Mac is not supported right now.')
|
||||
exit()
|
||||
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
|
||||
else:
|
||||
if args.multi_gpus:
|
||||
from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map
|
||||
with init_empty_weights():
|
||||
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation='sdpa', torch_dtype=torch.bfloat16)
|
||||
device_map = infer_auto_device_map(model, max_memory={0: "10GB", 1: "10GB"},
|
||||
no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'])
|
||||
device_id = device_map["llm.model.embed_tokens"]
|
||||
device_map["llm.lm_head"] = device_id # firtt and last layer should be in same device
|
||||
device_map["vpm"] = device_id
|
||||
device_map["resampler"] = device_id
|
||||
device_id2 = device_map["llm.model.layers.26"]
|
||||
device_map["llm.model.layers.8"] = device_id2
|
||||
device_map["llm.model.layers.9"] = device_id2
|
||||
device_map["llm.model.layers.10"] = device_id2
|
||||
device_map["llm.model.layers.11"] = device_id2
|
||||
device_map["llm.model.layers.12"] = device_id2
|
||||
device_map["llm.model.layers.13"] = device_id2
|
||||
device_map["llm.model.layers.14"] = device_id2
|
||||
device_map["llm.model.layers.15"] = device_id2
|
||||
device_map["llm.model.layers.16"] = device_id2
|
||||
#print(device_map)
|
||||
|
||||
model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
|
||||
else:
|
||||
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
|
||||
model = model.to(device=device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model.eval()
|
||||
|
||||
|
||||
|
||||
|
||||
ERROR_MSG = "Error, please retry"
|
||||
model_name = 'MiniCPM-V 2.6'
|
||||
MAX_NUM_FRAMES = 64
|
||||
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
|
||||
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
|
||||
|
||||
def get_file_extension(filename):
|
||||
return os.path.splitext(filename)[1].lower()
|
||||
|
||||
def is_image(filename):
|
||||
return get_file_extension(filename) in IMAGE_EXTENSIONS
|
||||
|
||||
def is_video(filename):
|
||||
return get_file_extension(filename) in VIDEO_EXTENSIONS
|
||||
|
||||
|
||||
form_radio = {
|
||||
'choices': ['Beam Search', 'Sampling'],
|
||||
#'value': 'Beam Search',
|
||||
'value': 'Sampling',
|
||||
'interactive': True,
|
||||
'label': 'Decode Type'
|
||||
}
|
||||
|
||||
|
||||
def create_component(params, comp='Slider'):
|
||||
if comp == 'Slider':
|
||||
return gr.Slider(
|
||||
minimum=params['minimum'],
|
||||
maximum=params['maximum'],
|
||||
value=params['value'],
|
||||
step=params['step'],
|
||||
interactive=params['interactive'],
|
||||
label=params['label']
|
||||
)
|
||||
elif comp == 'Radio':
|
||||
return gr.Radio(
|
||||
choices=params['choices'],
|
||||
value=params['value'],
|
||||
interactive=params['interactive'],
|
||||
label=params['label']
|
||||
)
|
||||
elif comp == 'Button':
|
||||
return gr.Button(
|
||||
value=params['value'],
|
||||
interactive=True
|
||||
)
|
||||
|
||||
|
||||
def create_multimodal_input(upload_image_disabled=False, upload_video_disabled=False):
|
||||
return mgr.MultimodalInput(upload_image_button_props={'label': 'Upload Image', 'disabled': upload_image_disabled, 'file_count': 'multiple'},
|
||||
upload_video_button_props={'label': 'Upload Video', 'disabled': upload_video_disabled, 'file_count': 'single'},
|
||||
submit_button_props={'label': 'Submit'})
|
||||
|
||||
|
||||
def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
|
||||
try:
|
||||
print('msgs:', msgs)
|
||||
answer = model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=tokenizer,
|
||||
**params
|
||||
)
|
||||
res = re.sub(r'(<box>.*</box>)', '', answer)
|
||||
res = res.replace('<ref>', '')
|
||||
res = res.replace('</ref>', '')
|
||||
res = res.replace('<box>', '')
|
||||
answer = res.replace('</box>', '')
|
||||
print('answer:', answer)
|
||||
return 0, answer, None, None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
traceback.print_exc()
|
||||
return -1, ERROR_MSG, None, None
|
||||
|
||||
|
||||
def encode_image(image):
|
||||
if not isinstance(image, Image.Image):
|
||||
if hasattr(image, 'path'):
|
||||
image = Image.open(image.path).convert("RGB")
|
||||
else:
|
||||
image = Image.open(image.file.path).convert("RGB")
|
||||
# resize to max_size
|
||||
max_size = 448*16
|
||||
if max(image.size) > max_size:
|
||||
w,h = image.size
|
||||
if w > h:
|
||||
new_w = max_size
|
||||
new_h = int(h * max_size / w)
|
||||
else:
|
||||
new_h = max_size
|
||||
new_w = int(w * max_size / h)
|
||||
image = image.resize((new_w, new_h), resample=Image.BICUBIC)
|
||||
return image
|
||||
## save by BytesIO and convert to base64
|
||||
#buffered = io.BytesIO()
|
||||
#image.save(buffered, format="png")
|
||||
#im_b64 = base64.b64encode(buffered.getvalue()).decode()
|
||||
#return {"type": "image", "pairs": im_b64}
|
||||
|
||||
|
||||
def encode_video(video):
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
idxs = [int(i * gap + gap / 2) for i in range(n)]
|
||||
return [l[i] for i in idxs]
|
||||
|
||||
if hasattr(video, 'path'):
|
||||
vr = VideoReader(video.path, ctx=cpu(0))
|
||||
else:
|
||||
vr = VideoReader(video.file.path, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1) # FPS
|
||||
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
||||
if len(frame_idx)>MAX_NUM_FRAMES:
|
||||
frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
|
||||
video = vr.get_batch(frame_idx).asnumpy()
|
||||
video = [Image.fromarray(v.astype('uint8')) for v in video]
|
||||
video = [encode_image(v) for v in video]
|
||||
print('video frames:', len(video))
|
||||
return video
|
||||
|
||||
|
||||
def check_mm_type(mm_file):
|
||||
if hasattr(mm_file, 'path'):
|
||||
path = mm_file.path
|
||||
else:
|
||||
path = mm_file.file.path
|
||||
if is_image(path):
|
||||
return "image"
|
||||
if is_video(path):
|
||||
return "video"
|
||||
return None
|
||||
|
||||
|
||||
def encode_mm_file(mm_file):
|
||||
if check_mm_type(mm_file) == 'image':
|
||||
return [encode_image(mm_file)]
|
||||
if check_mm_type(mm_file) == 'video':
|
||||
return encode_video(mm_file)
|
||||
return None
|
||||
|
||||
def make_text(text):
|
||||
#return {"type": "text", "pairs": text} # # For remote call
|
||||
return text
|
||||
|
||||
def encode_message(_question):
|
||||
files = _question.files
|
||||
question = _question.text
|
||||
pattern = r"\[mm_media\]\d+\[/mm_media\]"
|
||||
matches = re.split(pattern, question)
|
||||
message = []
|
||||
if len(matches) != len(files) + 1:
|
||||
gr.Warning("Number of Images not match the placeholder in text, please refresh the page to restart!")
|
||||
assert len(matches) == len(files) + 1
|
||||
|
||||
text = matches[0].strip()
|
||||
if text:
|
||||
message.append(make_text(text))
|
||||
for i in range(len(files)):
|
||||
message += encode_mm_file(files[i])
|
||||
text = matches[i + 1].strip()
|
||||
if text:
|
||||
message.append(make_text(text))
|
||||
return message
|
||||
|
||||
|
||||
def check_has_videos(_question):
|
||||
images_cnt = 0
|
||||
videos_cnt = 0
|
||||
for file in _question.files:
|
||||
if check_mm_type(file) == "image":
|
||||
images_cnt += 1
|
||||
else:
|
||||
videos_cnt += 1
|
||||
return images_cnt, videos_cnt
|
||||
|
||||
|
||||
def count_video_frames(_context):
|
||||
num_frames = 0
|
||||
for message in _context:
|
||||
for item in message["content"]:
|
||||
#if item["type"] == "image": # For remote call
|
||||
if isinstance(item, Image.Image):
|
||||
num_frames += 1
|
||||
return num_frames
|
||||
|
||||
|
||||
def respond(_question, _chat_bot, _app_cfg, params_form):
|
||||
_context = _app_cfg['ctx'].copy()
|
||||
_context.append({'role': 'user', 'content': encode_message(_question)})
|
||||
|
||||
images_cnt = _app_cfg['images_cnt']
|
||||
videos_cnt = _app_cfg['videos_cnt']
|
||||
files_cnts = check_has_videos(_question)
|
||||
if files_cnts[1] + videos_cnt > 1 or (files_cnts[1] + videos_cnt == 1 and files_cnts[0] + images_cnt > 0):
|
||||
gr.Warning("Only supports single video file input right now!")
|
||||
return _question, _chat_bot, _app_cfg
|
||||
|
||||
if params_form == 'Beam Search':
|
||||
params = {
|
||||
'sampling': False,
|
||||
'num_beams': 3,
|
||||
'repetition_penalty': 1.2,
|
||||
"max_new_tokens": 2048
|
||||
}
|
||||
else:
|
||||
params = {
|
||||
'sampling': True,
|
||||
'top_p': 0.8,
|
||||
'top_k': 100,
|
||||
'temperature': 0.7,
|
||||
'repetition_penalty': 1.05,
|
||||
"max_new_tokens": 2048
|
||||
}
|
||||
|
||||
if files_cnts[1] + videos_cnt > 0:
|
||||
params["max_inp_length"] = 4352 # 4096+256
|
||||
params["use_image_id"] = False
|
||||
params["max_slice_nums"] = 1 if count_video_frames(_context) > 16 else 2
|
||||
|
||||
code, _answer, _, sts = chat("", _context, None, params)
|
||||
|
||||
images_cnt += files_cnts[0]
|
||||
videos_cnt += files_cnts[1]
|
||||
_context.append({"role": "assistant", "content": [make_text(_answer)]})
|
||||
_chat_bot.append((_question, _answer))
|
||||
if code == 0:
|
||||
_app_cfg['ctx']=_context
|
||||
_app_cfg['sts']=sts
|
||||
_app_cfg['images_cnt'] = images_cnt
|
||||
_app_cfg['videos_cnt'] = videos_cnt
|
||||
|
||||
upload_image_disabled = videos_cnt > 0
|
||||
upload_video_disabled = videos_cnt > 0 or images_cnt > 0
|
||||
return create_multimodal_input(upload_image_disabled, upload_video_disabled), _chat_bot, _app_cfg
|
||||
|
||||
|
||||
def fewshot_add_demonstration(_image, _user_message, _assistant_message, _chat_bot, _app_cfg):
|
||||
ctx = _app_cfg["ctx"]
|
||||
message_item = []
|
||||
if _image is not None:
|
||||
image = Image.open(_image).convert("RGB")
|
||||
ctx.append({"role": "user", "content": [encode_image(image), make_text(_user_message)]})
|
||||
message_item.append({"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]})
|
||||
else:
|
||||
if _user_message:
|
||||
ctx.append({"role": "user", "content": [make_text(_user_message)]})
|
||||
message_item.append({"text": _user_message, "files": []})
|
||||
else:
|
||||
message_item.append(None)
|
||||
if _assistant_message:
|
||||
ctx.append({"role": "assistant", "content": [make_text(_assistant_message)]})
|
||||
message_item.append({"text": _assistant_message, "files": []})
|
||||
else:
|
||||
message_item.append(None)
|
||||
|
||||
_chat_bot.append(message_item)
|
||||
return None, "", "", _chat_bot, _app_cfg
|
||||
|
||||
|
||||
def fewshot_respond(_image, _user_message, _chat_bot, _app_cfg, params_form):
|
||||
user_message_contents = []
|
||||
_context = _app_cfg["ctx"].copy()
|
||||
if _image:
|
||||
image = Image.open(_image).convert("RGB")
|
||||
user_message_contents += [encode_image(image)]
|
||||
if _user_message:
|
||||
user_message_contents += [make_text(_user_message)]
|
||||
if user_message_contents:
|
||||
_context.append({"role": "user", "content": user_message_contents})
|
||||
|
||||
if params_form == 'Beam Search':
|
||||
params = {
|
||||
'sampling': False,
|
||||
'num_beams': 3,
|
||||
'repetition_penalty': 1.2,
|
||||
"max_new_tokens": 2048
|
||||
}
|
||||
else:
|
||||
params = {
|
||||
'sampling': True,
|
||||
'top_p': 0.8,
|
||||
'top_k': 100,
|
||||
'temperature': 0.7,
|
||||
'repetition_penalty': 1.05,
|
||||
"max_new_tokens": 2048
|
||||
}
|
||||
|
||||
code, _answer, _, sts = chat("", _context, None, params)
|
||||
|
||||
_context.append({"role": "assistant", "content": [make_text(_answer)]})
|
||||
|
||||
if _image:
|
||||
_chat_bot.append([
|
||||
{"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]},
|
||||
{"text": _answer, "files": []}
|
||||
])
|
||||
else:
|
||||
_chat_bot.append([
|
||||
{"text": _user_message, "files": [_image]},
|
||||
{"text": _answer, "files": []}
|
||||
])
|
||||
if code == 0:
|
||||
_app_cfg['ctx']=_context
|
||||
_app_cfg['sts']=sts
|
||||
return None, '', '', _chat_bot, _app_cfg
|
||||
|
||||
|
||||
def regenerate_button_clicked(_question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg, params_form):
|
||||
if len(_chat_bot) <= 1 or not _chat_bot[-1][1]:
|
||||
gr.Warning('No question for regeneration.')
|
||||
return '', _image, _user_message, _assistant_message, _chat_bot, _app_cfg
|
||||
if _app_cfg["chat_type"] == "Chat":
|
||||
images_cnt = _app_cfg['images_cnt']
|
||||
videos_cnt = _app_cfg['videos_cnt']
|
||||
_question = _chat_bot[-1][0]
|
||||
_chat_bot = _chat_bot[:-1]
|
||||
_app_cfg['ctx'] = _app_cfg['ctx'][:-2]
|
||||
files_cnts = check_has_videos(_question)
|
||||
images_cnt -= files_cnts[0]
|
||||
videos_cnt -= files_cnts[1]
|
||||
_app_cfg['images_cnt'] = images_cnt
|
||||
_app_cfg['videos_cnt'] = videos_cnt
|
||||
upload_image_disabled = videos_cnt > 0
|
||||
upload_video_disabled = videos_cnt > 0 or images_cnt > 0
|
||||
_question, _chat_bot, _app_cfg = respond(_question, _chat_bot, _app_cfg, params_form)
|
||||
return _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
|
||||
else:
|
||||
last_message = _chat_bot[-1][0]
|
||||
last_image = None
|
||||
last_user_message = ''
|
||||
if last_message.text:
|
||||
last_user_message = last_message.text
|
||||
if last_message.files:
|
||||
last_image = last_message.files[0].file.path
|
||||
_chat_bot = _chat_bot[:-1]
|
||||
_app_cfg['ctx'] = _app_cfg['ctx'][:-2]
|
||||
_image, _user_message, _assistant_message, _chat_bot, _app_cfg = fewshot_respond(last_image, last_user_message, _chat_bot, _app_cfg, params_form)
|
||||
return _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
|
||||
|
||||
|
||||
def flushed():
|
||||
return gr.update(interactive=True)
|
||||
|
||||
|
||||
def clear(txt_message, chat_bot, app_session):
|
||||
txt_message.files.clear()
|
||||
txt_message.text = ''
|
||||
chat_bot = copy.deepcopy(init_conversation)
|
||||
app_session['sts'] = None
|
||||
app_session['ctx'] = []
|
||||
app_session['images_cnt'] = 0
|
||||
app_session['videos_cnt'] = 0
|
||||
return create_multimodal_input(), chat_bot, app_session, None, '', ''
|
||||
|
||||
|
||||
def select_chat_type(_tab, _app_cfg):
|
||||
_app_cfg["chat_type"] = _tab
|
||||
return _app_cfg
|
||||
|
||||
|
||||
init_conversation = [
|
||||
[
|
||||
None,
|
||||
{
|
||||
# The first message of bot closes the typewriter.
|
||||
"text": "You can talk to me now",
|
||||
"flushing": False
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
css = """
|
||||
video { height: auto !important; }
|
||||
.example label { font-size: 16px;}
|
||||
"""
|
||||
|
||||
introduction = """
|
||||
|
||||
## Features:
|
||||
1. Chat with single image
|
||||
2. Chat with multiple images
|
||||
3. Chat with video
|
||||
4. In-context few-shot learning
|
||||
|
||||
Click `How to use` tab to see examples.
|
||||
"""
|
||||
|
||||
|
||||
with gr.Blocks(css=css) as demo:
|
||||
with gr.Tab(model_name):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=300):
|
||||
gr.Markdown(value=introduction)
|
||||
params_form = create_component(form_radio, comp='Radio')
|
||||
regenerate = create_component({'value': 'Regenerate'}, comp='Button')
|
||||
clear_button = create_component({'value': 'Clear History'}, comp='Button')
|
||||
|
||||
with gr.Column(scale=3, min_width=500):
|
||||
app_session = gr.State({'sts':None,'ctx':[], 'images_cnt': 0, 'videos_cnt': 0, 'chat_type': 'Chat'})
|
||||
chat_bot = mgr.Chatbot(label=f"Chat with {model_name}", value=copy.deepcopy(init_conversation), height=600, flushing=False, bubble_full_width=False)
|
||||
|
||||
with gr.Tab("Chat") as chat_tab:
|
||||
txt_message = create_multimodal_input()
|
||||
chat_tab_label = gr.Textbox(value="Chat", interactive=False, visible=False)
|
||||
|
||||
txt_message.submit(
|
||||
respond,
|
||||
[txt_message, chat_bot, app_session, params_form],
|
||||
[txt_message, chat_bot, app_session]
|
||||
)
|
||||
|
||||
with gr.Tab("Few Shot") as fewshot_tab:
|
||||
fewshot_tab_label = gr.Textbox(value="Few Shot", interactive=False, visible=False)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
image_input = gr.Image(type="filepath", sources=["upload"])
|
||||
with gr.Column(scale=3):
|
||||
user_message = gr.Textbox(label="User")
|
||||
assistant_message = gr.Textbox(label="Assistant")
|
||||
with gr.Row():
|
||||
add_demonstration_button = gr.Button("Add Example")
|
||||
generate_button = gr.Button(value="Generate", variant="primary")
|
||||
add_demonstration_button.click(
|
||||
fewshot_add_demonstration,
|
||||
[image_input, user_message, assistant_message, chat_bot, app_session],
|
||||
[image_input, user_message, assistant_message, chat_bot, app_session]
|
||||
)
|
||||
generate_button.click(
|
||||
fewshot_respond,
|
||||
[image_input, user_message, chat_bot, app_session, params_form],
|
||||
[image_input, user_message, assistant_message, chat_bot, app_session]
|
||||
)
|
||||
|
||||
chat_tab.select(
|
||||
select_chat_type,
|
||||
[chat_tab_label, app_session],
|
||||
[app_session]
|
||||
)
|
||||
chat_tab.select( # do clear
|
||||
clear,
|
||||
[txt_message, chat_bot, app_session],
|
||||
[txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
|
||||
)
|
||||
fewshot_tab.select(
|
||||
select_chat_type,
|
||||
[fewshot_tab_label, app_session],
|
||||
[app_session]
|
||||
)
|
||||
fewshot_tab.select( # do clear
|
||||
clear,
|
||||
[txt_message, chat_bot, app_session],
|
||||
[txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
|
||||
)
|
||||
chat_bot.flushed(
|
||||
flushed,
|
||||
outputs=[txt_message]
|
||||
)
|
||||
regenerate.click(
|
||||
regenerate_button_clicked,
|
||||
[txt_message, image_input, user_message, assistant_message, chat_bot, app_session, params_form],
|
||||
[txt_message, image_input, user_message, assistant_message, chat_bot, app_session]
|
||||
)
|
||||
clear_button.click(
|
||||
clear,
|
||||
[txt_message, chat_bot, app_session],
|
||||
[txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
|
||||
)
|
||||
|
||||
with gr.Tab("How to use"):
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
image_example = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/m_bear2.gif", label='1. Chat with single or multiple images', interactive=False, width=400, elem_classes="example")
|
||||
example2 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/video2.gif", label='2. Chat with video', interactive=False, width=400, elem_classes="example")
|
||||
example3 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/fshot.gif", label='3. Few shot', interactive=False, width=400, elem_classes="example")
|
||||
|
||||
|
||||
# launch
|
||||
demo.launch(share=False, debug=True, show_api=False, server_port=8885, server_name="0.0.0.0")
|
||||
|
||||