Update to MiniCPM-V 2.6
1002
README_en.md
997
README_zh.md
BIN
assets/gif_cases/ai.gif
Normal file
|
After Width: | Height: | Size: 5.5 MiB |
BIN
assets/gif_cases/beer.gif
Normal file
|
After Width: | Height: | Size: 4.4 MiB |
BIN
assets/gif_cases/mb.gif
Normal file
|
After Width: | Height: | Size: 13 MiB |
BIN
assets/gif_cases/rabbit.gif
Normal file
|
After Width: | Height: | Size: 7.7 MiB |
|
Before Width: | Height: | Size: 14 MiB After Width: | Height: | Size: 10 MiB |
BIN
assets/gif_cases/wfh.gif
Normal file
|
After Width: | Height: | Size: 7.8 MiB |
BIN
assets/gif_cases/zoo.gif
Normal file
|
After Width: | Height: | Size: 3.8 MiB |
BIN
assets/minicpmv2_6/ICL-Mem.png
Normal file
|
After Width: | Height: | Size: 1.9 MiB |
BIN
assets/minicpmv2_6/ICL-elec.png
Normal file
|
After Width: | Height: | Size: 4.4 MiB |
BIN
assets/minicpmv2_6/multi_img-bike.png
Normal file
|
After Width: | Height: | Size: 8.0 MiB |
BIN
assets/minicpmv2_6/multi_img-code.png
Normal file
|
After Width: | Height: | Size: 4.5 MiB |
BIN
assets/minicpmv2_6/multi_img-menu.png
Normal file
|
After Width: | Height: | Size: 3.7 MiB |
BIN
assets/minicpmv2_6/multiling-medal.png
Normal file
|
After Width: | Height: | Size: 2.1 MiB |
BIN
assets/minicpmv2_6/multiling-olympic.png
Normal file
|
After Width: | Height: | Size: 2.6 MiB |
BIN
assets/radar_final.png
Normal file
|
After Width: | Height: | Size: 1.1 MiB |
76
chat.py
@@ -183,13 +183,87 @@ class MiniCPMV2_5:
|
|||||||
)
|
)
|
||||||
return answer
|
return answer
|
||||||
|
|
||||||
|
class MiniCPMV2_6:
|
||||||
|
def __init__(self, model_path, multi_gpus=False) -> None:
|
||||||
|
|
||||||
|
print('torch_version:', torch.__version__)
|
||||||
|
if multi_gpus: # inference on multi-gpus
|
||||||
|
from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map
|
||||||
|
with init_empty_weights():
|
||||||
|
model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
|
||||||
|
attn_implementation='sdpa', torch_dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
device_map = infer_auto_device_map(model, max_memory={0: "10GB", 1: "10GB"},
|
||||||
|
no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'])
|
||||||
|
device_id = device_map["llm.model.embed_tokens"]
|
||||||
|
device_map["llm.lm_head"] = device_id # first and last layer of llm should be in the same device
|
||||||
|
device_map["vpm"] = device_id
|
||||||
|
device_map["resampler"] = device_id
|
||||||
|
device_id2 = device_map["llm.model.layers.26"]
|
||||||
|
device_map["llm.model.layers.8"] = device_id2
|
||||||
|
device_map["llm.model.layers.9"] = device_id2
|
||||||
|
device_map["llm.model.layers.10"] = device_id2
|
||||||
|
device_map["llm.model.layers.11"] = device_id2
|
||||||
|
device_map["llm.model.layers.12"] = device_id2
|
||||||
|
device_map["llm.model.layers.13"] = device_id2
|
||||||
|
device_map["llm.model.layers.14"] = device_id2
|
||||||
|
device_map["llm.model.layers.15"] = device_id2
|
||||||
|
device_map["llm.model.layers.16"] = device_id2
|
||||||
|
print(device_map)
|
||||||
|
|
||||||
|
self.model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
|
||||||
|
self.model.eval()
|
||||||
|
else:
|
||||||
|
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
|
||||||
|
attn_implementation='sdpa', torch_dtype=torch.bfloat16)
|
||||||
|
self.model.eval().cuda()
|
||||||
|
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
|
||||||
|
def chat(self, input):
|
||||||
|
image = None
|
||||||
|
if "image" in input and len(input["image"]) > 10: # legacy API
|
||||||
|
try:
|
||||||
|
image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
|
||||||
|
except Exception as e:
|
||||||
|
return "Image decode error"
|
||||||
|
|
||||||
|
msgs = json.loads(input["question"])
|
||||||
|
|
||||||
|
for msg in msgs:
|
||||||
|
contents = msg.pop('content') # support str or List[Dict]
|
||||||
|
if isinstance(contents, str):
|
||||||
|
contents = [contents]
|
||||||
|
|
||||||
|
new_cnts = []
|
||||||
|
for c in contents:
|
||||||
|
if isinstance(c, dict):
|
||||||
|
if c['type'] == 'text':
|
||||||
|
c = c['pairs']
|
||||||
|
elif c['type'] == 'image':
|
||||||
|
c = Image.open(io.BytesIO(base64.b64decode(c["pairs"]))).convert('RGB')
|
||||||
|
else:
|
||||||
|
raise ValueError("content type only support text and image.")
|
||||||
|
new_cnts.append(c)
|
||||||
|
msg['content'] = new_cnts
|
||||||
|
print(f'msgs: {str(msgs)}')
|
||||||
|
|
||||||
|
answer = self.model.chat(
|
||||||
|
image=image,
|
||||||
|
msgs=msgs,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
)
|
||||||
|
return answer
|
||||||
|
|
||||||
|
|
||||||
class MiniCPMVChat:
|
class MiniCPMVChat:
|
||||||
def __init__(self, model_path) -> None:
|
def __init__(self, model_path, multi_gpus=False) -> None:
|
||||||
if '12B' in model_path:
|
if '12B' in model_path:
|
||||||
self.model = OmniLMM12B(model_path)
|
self.model = OmniLMM12B(model_path)
|
||||||
elif 'MiniCPM-Llama3-V' in model_path:
|
elif 'MiniCPM-Llama3-V' in model_path:
|
||||||
self.model = MiniCPMV2_5(model_path)
|
self.model = MiniCPMV2_5(model_path)
|
||||||
|
elif 'MiniCPM-V-2_6' in model_path:
|
||||||
|
self.model = MiniCPMV2_6(model_path, multi_gpus)
|
||||||
else:
|
else:
|
||||||
self.model = MiniCPMV(model_path)
|
self.model = MiniCPMV(model_path)
|
||||||
|
|
||||||
|
|||||||
30
docs/faqs.md
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
### FAQs
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Q: How to choose between sampling or beam search for inference </summary>
|
||||||
|
|
||||||
|
In various scenarios, the quality of results obtained from beam search and sampling decoding strategies can vary. You can determine your decoding strategy based on the following aspects:
|
||||||
|
|
||||||
|
If you have the following needs, consider using sampling decoding:
|
||||||
|
|
||||||
|
1. You require faster inference speed.
|
||||||
|
2. You wish for a streaming generation approach.
|
||||||
|
3. Your task necessitates some open-ended responses.
|
||||||
|
|
||||||
|
If your task is about providing deterministic answers, you might want to experiment with beam search to see if it can achieve better outcomes.
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Q: How to ensure that the model generates results of sufficient length</summary>
|
||||||
|
|
||||||
|
We've observed that during multi-language inference on MiniCPM-V 2.6, the generation sometimes ends prematurely. You can improve the results by passing a `min_new_tokens` parameter.
|
||||||
|
```python
|
||||||
|
res = model.chat(
|
||||||
|
image=None,
|
||||||
|
msgs=msgs,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
min_new_tokens=100
|
||||||
|
)
|
||||||
|
```
|
||||||
|
</details>
|
||||||
@@ -105,7 +105,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def conversation_to_ids(conversation, tokenizer, llm_type=None):
|
def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
|
||||||
"""
|
"""
|
||||||
for single image multi-turn conversation
|
for single image multi-turn conversation
|
||||||
conversation: [{'role': 'user', 'content': 'Describe this image'},
|
conversation: [{'role': 'user', 'content': 'Describe this image'},
|
||||||
@@ -115,6 +115,10 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
|
|||||||
input_ids, context, raw_msg = conversation_to_ids_llama3(
|
input_ids, context, raw_msg = conversation_to_ids_llama3(
|
||||||
conversation, tokenizer
|
conversation, tokenizer
|
||||||
)
|
)
|
||||||
|
elif llm_type == "qwen2":
|
||||||
|
input_ids, context, raw_msg = conversation_to_ids_qwen2(
|
||||||
|
conversation, tokenizer
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
input_ids, context, raw_msg = conversation_to_ids_minicpm(
|
input_ids, context, raw_msg = conversation_to_ids_minicpm(
|
||||||
conversation, tokenizer
|
conversation, tokenizer
|
||||||
@@ -125,6 +129,7 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
|
|||||||
|
|
||||||
# build target
|
# build target
|
||||||
target = torch.full_like(ids, -100, dtype=torch.int32)
|
target = torch.full_like(ids, -100, dtype=torch.int32)
|
||||||
|
|
||||||
for i in range(1, len(ids)):
|
for i in range(1, len(ids)):
|
||||||
if context[i] == 0:
|
if context[i] == 0:
|
||||||
target[i - 1] = ids[i]
|
target[i - 1] = ids[i]
|
||||||
@@ -133,14 +138,21 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
|
|||||||
target[i - 1] = tokenizer.eot_id
|
target[i - 1] = tokenizer.eot_id
|
||||||
else:
|
else:
|
||||||
target[i - 1] = tokenizer.eos_id
|
target[i - 1] = tokenizer.eos_id
|
||||||
|
|
||||||
# build image bound
|
# build image bound
|
||||||
image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
|
if new_schema:
|
||||||
image_start_tokens += 1
|
start_cond = (ids == tokenizer.im_start_id) | (ids == tokenizer.slice_start_id)
|
||||||
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
|
end_cond = (ids == tokenizer.im_end_id) | (ids == tokenizer.slice_end_id)
|
||||||
|
image_start_tokens = torch.where(start_cond)[0]
|
||||||
|
image_start_tokens += 1
|
||||||
|
image_end_tokens = torch.where(end_cond)[0]
|
||||||
|
else:
|
||||||
|
image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
|
||||||
|
image_start_tokens += 1
|
||||||
|
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
|
||||||
if len(image_start_tokens) != len(image_end_tokens):
|
if len(image_start_tokens) != len(image_end_tokens):
|
||||||
print("image start token != image end tokens")
|
print("image start token != image end tokens")
|
||||||
|
|
||||||
if len(image_start_tokens) > 0:
|
if len(image_start_tokens) > 0:
|
||||||
image_bound = torch.hstack(
|
image_bound = torch.hstack(
|
||||||
[image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
|
[image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
|
||||||
@@ -230,6 +242,46 @@ def conversation_to_ids_llama3(conversation, tokenizer):
|
|||||||
return input_ids, context, raw_msg
|
return input_ids, context, raw_msg
|
||||||
|
|
||||||
|
|
||||||
|
def conversation_to_ids_qwen2(conversation, tokenizer):
|
||||||
|
raw_msg = ""
|
||||||
|
chat = []
|
||||||
|
context = []
|
||||||
|
for idx, msg in enumerate(conversation):
|
||||||
|
role = msg["role"]
|
||||||
|
message = msg["content"]
|
||||||
|
assert role in ["user", "assistant"]
|
||||||
|
if role == "user":
|
||||||
|
prefix = "user"
|
||||||
|
else:
|
||||||
|
prefix = "assistant"
|
||||||
|
chat.append({"role":prefix, "content":message})
|
||||||
|
raw_msg += prefix + message
|
||||||
|
assert set([i['role'] for i in chat]) & set(['assistant'])
|
||||||
|
|
||||||
|
ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
|
||||||
|
input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False)
|
||||||
|
input_ids = np.array(input_ids)
|
||||||
|
|
||||||
|
start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0]
|
||||||
|
assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0]
|
||||||
|
end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0]
|
||||||
|
|
||||||
|
context = np.ones_like(input_ids, dtype=np.int8)
|
||||||
|
|
||||||
|
for assistant_idx in assistant_idxs:
|
||||||
|
if assistant_idx-1 in set(start_idxs):
|
||||||
|
st = assistant_idx + 1
|
||||||
|
for end_idx in end_idxs:
|
||||||
|
if end_idx > st:
|
||||||
|
context[st: end_idx + 1] = 0
|
||||||
|
break
|
||||||
|
|
||||||
|
input_ids = np.hstack(input_ids)
|
||||||
|
context = np.hstack(context)
|
||||||
|
return input_ids, context, raw_msg
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess(
|
def preprocess(
|
||||||
image,
|
image,
|
||||||
conversation,
|
conversation,
|
||||||
@@ -256,8 +308,14 @@ def preprocess(
|
|||||||
default_image_placeholder = (
|
default_image_placeholder = (
|
||||||
tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
|
tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
|
||||||
)
|
)
|
||||||
|
new_schema = False
|
||||||
|
use_image_id = False
|
||||||
|
if llm_type=='qwen2':
|
||||||
|
new_schema = True
|
||||||
|
use_image_id = True
|
||||||
if slice_config:
|
if slice_config:
|
||||||
images = []
|
images = []
|
||||||
|
image_id_cnt = 0
|
||||||
source_image, patches, best_grid = slice_image(
|
source_image, patches, best_grid = slice_image(
|
||||||
image,
|
image,
|
||||||
slice_config["max_slice_nums"],
|
slice_config["max_slice_nums"],
|
||||||
@@ -270,9 +328,11 @@ def preprocess(
|
|||||||
for i in range(len(patches)):
|
for i in range(len(patches)):
|
||||||
for j in range(len(patches[0])):
|
for j in range(len(patches[0])):
|
||||||
images.append(patches[i][j])
|
images.append(patches[i][j])
|
||||||
|
if use_image_id:
|
||||||
|
image_placeholder = f'{tokenizer.im_id_start}{idx}{tokenizer.im_id_end}' + image_placeholder
|
||||||
|
image_id_cnt += 1
|
||||||
image_placeholder += get_grid_placeholder(
|
image_placeholder += get_grid_placeholder(
|
||||||
tokenizer, best_grid, query_nums)
|
tokenizer, best_grid, query_nums, new_schema = new_schema)
|
||||||
images = [transform(i) for i in images]
|
images = [transform(i) for i in images]
|
||||||
else:
|
else:
|
||||||
images = [transform(image)]
|
images = [transform(image)]
|
||||||
@@ -286,7 +346,7 @@ def preprocess(
|
|||||||
image_placeholder + "\n" + conversation[0]["content"]
|
image_placeholder + "\n" + conversation[0]["content"]
|
||||||
)
|
)
|
||||||
|
|
||||||
input_dict = conversation_to_ids(conversation, tokenizer, llm_type)
|
input_dict = conversation_to_ids(conversation, tokenizer, llm_type, new_schema)
|
||||||
|
|
||||||
if batch_vision:
|
if batch_vision:
|
||||||
tgt_sizes = []
|
tgt_sizes = []
|
||||||
@@ -424,7 +484,7 @@ def split_to_patches(image, grid):
|
|||||||
return patches
|
return patches
|
||||||
|
|
||||||
|
|
||||||
def get_grid_placeholder(tokenizer, grid, query_num):
|
def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):
|
||||||
image_placeholder = (
|
image_placeholder = (
|
||||||
tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
|
tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
|
||||||
)
|
)
|
||||||
@@ -437,7 +497,10 @@ def get_grid_placeholder(tokenizer, grid, query_num):
|
|||||||
for j in range(cols):
|
for j in range(cols):
|
||||||
lines.append(image_placeholder)
|
lines.append(image_placeholder)
|
||||||
slices.append("".join(lines))
|
slices.append("".join(lines))
|
||||||
slice_placeholder = tokenizer.slice_start + \
|
if new_schema:
|
||||||
|
slice_placeholder = '\n'.join(slices)
|
||||||
|
else:
|
||||||
|
slice_placeholder = tokenizer.slice_start + \
|
||||||
"\n".join(slices) + tokenizer.slice_end
|
"\n".join(slices) + tokenizer.slice_end
|
||||||
return slice_placeholder
|
return slice_placeholder
|
||||||
|
|
||||||
@@ -455,4 +518,4 @@ def reshape_by_patch(image_tensor, patch_size):
|
|||||||
patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)
|
patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)
|
||||||
patches = patches.permute(0, 1, 3, 2).reshape(
|
patches = patches.permute(0, 1, 3, 2).reshape(
|
||||||
image_tensor.size(0), patch_size, -1)
|
image_tensor.size(0), patch_size, -1)
|
||||||
return patches
|
return patches
|
||||||
@@ -6,6 +6,8 @@ from dataclasses import dataclass, field
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, List, Optional, Union, Literal, Tuple
|
from typing import Dict, List, Optional, Union, Literal, Tuple
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate.utils import DistributedType
|
from accelerate.utils import DistributedType
|
||||||
@@ -130,6 +132,18 @@ def make_supervised_data_module(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_transform():
|
||||||
|
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
|
||||||
|
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
|
||||||
|
return transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(
|
||||||
|
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def get_parameter_number(model):
|
def get_parameter_number(model):
|
||||||
trainable_params, all_param = 0, 0
|
trainable_params, all_param = 0, 0
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
@@ -248,10 +262,11 @@ def train():
|
|||||||
else:
|
else:
|
||||||
batch_vision = False
|
batch_vision = False
|
||||||
|
|
||||||
|
transform_func = build_transform()
|
||||||
data_module = make_supervised_data_module(
|
data_module = make_supervised_data_module(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
data_args=data_args,
|
data_args=data_args,
|
||||||
transform=model.transform,
|
transform=transform_func,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
slice_config=slice_config,
|
slice_config=slice_config,
|
||||||
llm_type=llm_type,
|
llm_type=llm_type,
|
||||||
|
|||||||
@@ -6,12 +6,15 @@ NODE_RANK=0
|
|||||||
MASTER_ADDR=localhost
|
MASTER_ADDR=localhost
|
||||||
MASTER_PORT=6001
|
MASTER_PORT=6001
|
||||||
|
|
||||||
MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
|
MODEL="openbmb/MiniCPM-V-2_6"
|
||||||
|
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5
|
||||||
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
|
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
|
||||||
# See the section for finetuning in README for more information.
|
# See the section for finetuning in README for more information.
|
||||||
DATA="path/to/trainging_data"
|
DATA="path/to/trainging_data"
|
||||||
EVAL_DATA="path/to/test_data"
|
EVAL_DATA="path/to/test_data"
|
||||||
LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
|
LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
DISTRIBUTED_ARGS="
|
DISTRIBUTED_ARGS="
|
||||||
--nproc_per_node $GPUS_PER_NODE \
|
--nproc_per_node $GPUS_PER_NODE \
|
||||||
@@ -28,10 +31,10 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
|
|||||||
--remove_unused_columns false \
|
--remove_unused_columns false \
|
||||||
--label_names "labels" \
|
--label_names "labels" \
|
||||||
--prediction_loss_only false \
|
--prediction_loss_only false \
|
||||||
--bf16 false \
|
--bf16 true \
|
||||||
--bf16_full_eval false \
|
--bf16_full_eval true \
|
||||||
--fp16 true \
|
--fp16 false \
|
||||||
--fp16_full_eval true \
|
--fp16_full_eval false \
|
||||||
--do_train \
|
--do_train \
|
||||||
--do_eval \
|
--do_eval \
|
||||||
--tune_vision true \
|
--tune_vision true \
|
||||||
@@ -40,8 +43,8 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
|
|||||||
--max_slice_nums 9 \
|
--max_slice_nums 9 \
|
||||||
--max_steps 10000 \
|
--max_steps 10000 \
|
||||||
--eval_steps 1000 \
|
--eval_steps 1000 \
|
||||||
--output_dir output/output_minicpmv2 \
|
--output_dir output/output_minicpmv26 \
|
||||||
--logging_dir output/output_minicpmv2 \
|
--logging_dir output/output_minicpmv26 \
|
||||||
--logging_strategy "steps" \
|
--logging_strategy "steps" \
|
||||||
--per_device_train_batch_size 1 \
|
--per_device_train_batch_size 1 \
|
||||||
--per_device_eval_batch_size 1 \
|
--per_device_eval_batch_size 1 \
|
||||||
|
|||||||
@@ -6,13 +6,14 @@ NODE_RANK=0
|
|||||||
MASTER_ADDR=localhost
|
MASTER_ADDR=localhost
|
||||||
MASTER_PORT=6001
|
MASTER_PORT=6001
|
||||||
|
|
||||||
MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
|
MODEL="openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5
|
||||||
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
|
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
|
||||||
# See the section for finetuning in README for more information.
|
# See the section for finetuning in README for more information.
|
||||||
DATA="path/to/trainging_data"
|
DATA="path/to/trainging_data"
|
||||||
EVAL_DATA="path/to/test_data"
|
EVAL_DATA="path/to/test_data"
|
||||||
LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
|
LLM_TYPE="qwen2"
|
||||||
|
# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
|
||||||
|
#if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE=llama3
|
||||||
DISTRIBUTED_ARGS="
|
DISTRIBUTED_ARGS="
|
||||||
--nproc_per_node $GPUS_PER_NODE \
|
--nproc_per_node $GPUS_PER_NODE \
|
||||||
--nnodes $NNODES \
|
--nnodes $NNODES \
|
||||||
@@ -42,12 +43,12 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
|
|||||||
--max_slice_nums 9 \
|
--max_slice_nums 9 \
|
||||||
--max_steps 10000 \
|
--max_steps 10000 \
|
||||||
--eval_steps 1000 \
|
--eval_steps 1000 \
|
||||||
--output_dir output/output_minicpmv2_lora \
|
--output_dir output/output__lora \
|
||||||
--logging_dir output/output_minicpmv2_lora \
|
--logging_dir output/output_lora \
|
||||||
--logging_strategy "steps" \
|
--logging_strategy "steps" \
|
||||||
--per_device_train_batch_size 2 \
|
--per_device_train_batch_size 1 \
|
||||||
--per_device_eval_batch_size 1 \
|
--per_device_eval_batch_size 1 \
|
||||||
--gradient_accumulation_steps 8 \
|
--gradient_accumulation_steps 1 \
|
||||||
--evaluation_strategy "steps" \
|
--evaluation_strategy "steps" \
|
||||||
--save_strategy "steps" \
|
--save_strategy "steps" \
|
||||||
--save_steps 1000 \
|
--save_steps 1000 \
|
||||||
|
|||||||
@@ -1,6 +1,76 @@
|
|||||||
# MiniCPM-V Finetuning
|
# MiniCPM-V Finetuning
|
||||||
|
|
||||||
|
|
||||||
|
We offer the official scripts for easy finetuning of the pretrained **MiniCPM-V-2_6**, **MiniCPM-Llama3-V 2.5** and **MiniCPM-V 2.0** on downstream tasks. Our finetune scripts use transformers Trainer and DeepSpeed by default.
|
||||||
|
|
||||||
|
### Data preparation
|
||||||
|
|
||||||
|
To prepare your finetuning data, you should formulate each sample as a dictionary consisting of an id, an image path list with an image, and a list of conversations. Then save data samples in JSON files.
|
||||||
|
|
||||||
|
For the vision-language example with image, you are required to provide **\<image\>** to define the position to insert the image embeddings. If you don't provide \<image\>, the image will be placed at the front of the conversation.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>
|
||||||
|
<b>vision-language example (vl_finetune_data.json) with 1 samples.</b>
|
||||||
|
</summary>
|
||||||
|
|
||||||
|
```
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "0",
|
||||||
|
"image": 'path/to/image_0.jpg',
|
||||||
|
"conversations": [
|
||||||
|
{
|
||||||
|
'role': 'user',
|
||||||
|
'content': '<image>\nHow many desserts are on the white plate?'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'role': 'assistant',
|
||||||
|
'content': 'There are three desserts on the white plate.'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'role': 'user',
|
||||||
|
'content': 'What type of desserts are they?'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'role': 'assistant',
|
||||||
|
'content': 'The desserts are cakes with bananas and pecans on top. They share similarities with donuts, but the presence of bananas and pecans differentiates them.'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'role': 'user',
|
||||||
|
'content': 'What is the setting of the image?'},
|
||||||
|
{
|
||||||
|
'role': 'assistant',
|
||||||
|
'content': 'The image is set on a table top with a plate containing the three desserts.'
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### Full-parameter finetuning
|
||||||
|
|
||||||
|
Full-parameter parameter finetuning requires updating all parameters of LLM in the whole training process. Please specify the correct MODEL path and DATA path in the shell scripts.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
MODEL="openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2
|
||||||
|
DATA="path/to/trainging_data" # json file
|
||||||
|
EVAL_DATA="path/to/test_data" # json file
|
||||||
|
```
|
||||||
|
|
||||||
|
To launch your training, run the following script:
|
||||||
|
|
||||||
|
```
|
||||||
|
sh finetune_ds.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Customizing Hyperparameters
|
||||||
|
To tailor the training process according to your specific requirements, you can adjust various hyperparameters. For comprehensive documentation on available hyperparameters and their functionalities, you can refer to the [official Transformers documentation](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments). Experimentation and fine-tuning of these parameters are essential for achieving optimal model performance tailored to your specific task and dataset.
|
||||||
|
# MiniCPM-V Finetuning
|
||||||
|
|
||||||
|
|
||||||
We offer the official scripts for easy finetuning of the pretrained **MiniCPM-Llama3-V 2.5** and **MiniCPM-V 2.0** on downstream tasks. Our finetune scripts use transformers Trainer and DeepSpeed by default.
|
We offer the official scripts for easy finetuning of the pretrained **MiniCPM-Llama3-V 2.5** and **MiniCPM-V 2.0** on downstream tasks. Our finetune scripts use transformers Trainer and DeepSpeed by default.
|
||||||
|
|
||||||
### Data preparation
|
### Data preparation
|
||||||
@@ -55,10 +125,10 @@ For the vision-language example with image, you are required to provide **\<imag
|
|||||||
Full-parameter parameter finetuning requires updating all parameters of LLM in the whole training process. Please specify the correct MODEL path, DATA path and LLM_TYPE in the shell scripts.
|
Full-parameter parameter finetuning requires updating all parameters of LLM in the whole training process. Please specify the correct MODEL path, DATA path and LLM_TYPE in the shell scripts.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
|
MODEL="openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2
|
||||||
DATA="path/to/trainging_data" # json file
|
DATA="path/to/trainging_data" # json file
|
||||||
EVAL_DATA="path/to/test_data" # json file
|
EVAL_DATA="path/to/test_data" # json file
|
||||||
LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
|
LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
|
||||||
```
|
```
|
||||||
|
|
||||||
To launch your training, run the following script:
|
To launch your training, run the following script:
|
||||||
@@ -82,7 +152,7 @@ After training, you could load the model with the path to the adapter. We advise
|
|||||||
```
|
```
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
from transformers import AutoModel
|
from transformers import AutoModel
|
||||||
model_type="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
|
model_type= "openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-Llama3-V-2_5 , openbmb/MiniCPM-V-2
|
||||||
path_to_adapter="path_to_your_fine_tuned_checkpoint"
|
path_to_adapter="path_to_your_fine_tuned_checkpoint"
|
||||||
|
|
||||||
model = AutoModel.from_pretrained(
|
model = AutoModel.from_pretrained(
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import deepspeed
|
import deepspeed
|
||||||
|
|||||||
@@ -29,5 +29,7 @@ uvicorn==0.24.0.post1
|
|||||||
sentencepiece==0.1.99
|
sentencepiece==0.1.99
|
||||||
accelerate==0.30.1
|
accelerate==0.30.1
|
||||||
socksio==1.0.0
|
socksio==1.0.0
|
||||||
gradio
|
gradio==4.22.0
|
||||||
gradio_client
|
gradio_client
|
||||||
|
http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/modelscope_studio-0.4.0.9-py3-none-any.whl
|
||||||
|
decord
|
||||||
|
|||||||
557
web_demo_2.6.py
Normal file
@@ -0,0 +1,557 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# encoding: utf-8
|
||||||
|
import torch
|
||||||
|
import argparse
|
||||||
|
from transformers import AutoModel, AutoTokenizer
|
||||||
|
import gradio as gr
|
||||||
|
from PIL import Image
|
||||||
|
from decord import VideoReader, cpu
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import copy
|
||||||
|
import requests
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import traceback
|
||||||
|
import re
|
||||||
|
import modelscope_studio as mgr
|
||||||
|
|
||||||
|
|
||||||
|
# README, How to run demo on different devices
|
||||||
|
|
||||||
|
# For Nvidia GPUs.
|
||||||
|
# python web_demo_2.6.py --device cuda
|
||||||
|
|
||||||
|
# For Mac with MPS (Apple silicon or AMD GPUs).
|
||||||
|
# PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo_2.6.py --device mps
|
||||||
|
|
||||||
|
# Argparser
|
||||||
|
parser = argparse.ArgumentParser(description='demo')
|
||||||
|
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
|
||||||
|
parser.add_argument('--multi-gpus', action='store_true', default=False, help='use multi-gpus')
|
||||||
|
args = parser.parse_args()
|
||||||
|
device = args.device
|
||||||
|
assert device in ['cuda', 'mps']
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
model_path = 'openbmb/MiniCPM-V-2_6'
|
||||||
|
if 'int4' in model_path:
|
||||||
|
if device == 'mps':
|
||||||
|
print('Error: running int4 model with bitsandbytes on Mac is not supported right now.')
|
||||||
|
exit()
|
||||||
|
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
else:
|
||||||
|
if args.multi_gpus:
|
||||||
|
from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map
|
||||||
|
with init_empty_weights():
|
||||||
|
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation='sdpa', torch_dtype=torch.bfloat16)
|
||||||
|
device_map = infer_auto_device_map(model, max_memory={0: "10GB", 1: "10GB"},
|
||||||
|
no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'])
|
||||||
|
device_id = device_map["llm.model.embed_tokens"]
|
||||||
|
device_map["llm.lm_head"] = device_id # firtt and last layer should be in same device
|
||||||
|
device_map["vpm"] = device_id
|
||||||
|
device_map["resampler"] = device_id
|
||||||
|
device_id2 = device_map["llm.model.layers.26"]
|
||||||
|
device_map["llm.model.layers.8"] = device_id2
|
||||||
|
device_map["llm.model.layers.9"] = device_id2
|
||||||
|
device_map["llm.model.layers.10"] = device_id2
|
||||||
|
device_map["llm.model.layers.11"] = device_id2
|
||||||
|
device_map["llm.model.layers.12"] = device_id2
|
||||||
|
device_map["llm.model.layers.13"] = device_id2
|
||||||
|
device_map["llm.model.layers.14"] = device_id2
|
||||||
|
device_map["llm.model.layers.15"] = device_id2
|
||||||
|
device_map["llm.model.layers.16"] = device_id2
|
||||||
|
#print(device_map)
|
||||||
|
|
||||||
|
model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
|
||||||
|
else:
|
||||||
|
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
|
||||||
|
model = model.to(device=device)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ERROR_MSG = "Error, please retry"
|
||||||
|
model_name = 'MiniCPM-V 2.6'
|
||||||
|
MAX_NUM_FRAMES = 64
|
||||||
|
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
|
||||||
|
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
|
||||||
|
|
||||||
|
def get_file_extension(filename):
|
||||||
|
return os.path.splitext(filename)[1].lower()
|
||||||
|
|
||||||
|
def is_image(filename):
|
||||||
|
return get_file_extension(filename) in IMAGE_EXTENSIONS
|
||||||
|
|
||||||
|
def is_video(filename):
|
||||||
|
return get_file_extension(filename) in VIDEO_EXTENSIONS
|
||||||
|
|
||||||
|
|
||||||
|
form_radio = {
|
||||||
|
'choices': ['Beam Search', 'Sampling'],
|
||||||
|
#'value': 'Beam Search',
|
||||||
|
'value': 'Sampling',
|
||||||
|
'interactive': True,
|
||||||
|
'label': 'Decode Type'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_component(params, comp='Slider'):
|
||||||
|
if comp == 'Slider':
|
||||||
|
return gr.Slider(
|
||||||
|
minimum=params['minimum'],
|
||||||
|
maximum=params['maximum'],
|
||||||
|
value=params['value'],
|
||||||
|
step=params['step'],
|
||||||
|
interactive=params['interactive'],
|
||||||
|
label=params['label']
|
||||||
|
)
|
||||||
|
elif comp == 'Radio':
|
||||||
|
return gr.Radio(
|
||||||
|
choices=params['choices'],
|
||||||
|
value=params['value'],
|
||||||
|
interactive=params['interactive'],
|
||||||
|
label=params['label']
|
||||||
|
)
|
||||||
|
elif comp == 'Button':
|
||||||
|
return gr.Button(
|
||||||
|
value=params['value'],
|
||||||
|
interactive=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_multimodal_input(upload_image_disabled=False, upload_video_disabled=False):
|
||||||
|
return mgr.MultimodalInput(upload_image_button_props={'label': 'Upload Image', 'disabled': upload_image_disabled, 'file_count': 'multiple'},
|
||||||
|
upload_video_button_props={'label': 'Upload Video', 'disabled': upload_video_disabled, 'file_count': 'single'},
|
||||||
|
submit_button_props={'label': 'Submit'})
|
||||||
|
|
||||||
|
|
||||||
|
def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
|
||||||
|
try:
|
||||||
|
print('msgs:', msgs)
|
||||||
|
answer = model.chat(
|
||||||
|
image=None,
|
||||||
|
msgs=msgs,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
**params
|
||||||
|
)
|
||||||
|
res = re.sub(r'(<box>.*</box>)', '', answer)
|
||||||
|
res = res.replace('<ref>', '')
|
||||||
|
res = res.replace('</ref>', '')
|
||||||
|
res = res.replace('<box>', '')
|
||||||
|
answer = res.replace('</box>', '')
|
||||||
|
print('answer:', answer)
|
||||||
|
return 0, answer, None, None
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
traceback.print_exc()
|
||||||
|
return -1, ERROR_MSG, None, None
|
||||||
|
|
||||||
|
|
||||||
|
def encode_image(image):
|
||||||
|
if not isinstance(image, Image.Image):
|
||||||
|
if hasattr(image, 'path'):
|
||||||
|
image = Image.open(image.path).convert("RGB")
|
||||||
|
else:
|
||||||
|
image = Image.open(image.file.path).convert("RGB")
|
||||||
|
# resize to max_size
|
||||||
|
max_size = 448*16
|
||||||
|
if max(image.size) > max_size:
|
||||||
|
w,h = image.size
|
||||||
|
if w > h:
|
||||||
|
new_w = max_size
|
||||||
|
new_h = int(h * max_size / w)
|
||||||
|
else:
|
||||||
|
new_h = max_size
|
||||||
|
new_w = int(w * max_size / h)
|
||||||
|
image = image.resize((new_w, new_h), resample=Image.BICUBIC)
|
||||||
|
return image
|
||||||
|
## save by BytesIO and convert to base64
|
||||||
|
#buffered = io.BytesIO()
|
||||||
|
#image.save(buffered, format="png")
|
||||||
|
#im_b64 = base64.b64encode(buffered.getvalue()).decode()
|
||||||
|
#return {"type": "image", "pairs": im_b64}
|
||||||
|
|
||||||
|
|
||||||
|
def encode_video(video):
|
||||||
|
def uniform_sample(l, n):
|
||||||
|
gap = len(l) / n
|
||||||
|
idxs = [int(i * gap + gap / 2) for i in range(n)]
|
||||||
|
return [l[i] for i in idxs]
|
||||||
|
|
||||||
|
if hasattr(video, 'path'):
|
||||||
|
vr = VideoReader(video.path, ctx=cpu(0))
|
||||||
|
else:
|
||||||
|
vr = VideoReader(video.file.path, ctx=cpu(0))
|
||||||
|
sample_fps = round(vr.get_avg_fps() / 1) # FPS
|
||||||
|
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
||||||
|
if len(frame_idx)>MAX_NUM_FRAMES:
|
||||||
|
frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
|
||||||
|
video = vr.get_batch(frame_idx).asnumpy()
|
||||||
|
video = [Image.fromarray(v.astype('uint8')) for v in video]
|
||||||
|
video = [encode_image(v) for v in video]
|
||||||
|
print('video frames:', len(video))
|
||||||
|
return video
|
||||||
|
|
||||||
|
|
||||||
|
def check_mm_type(mm_file):
|
||||||
|
if hasattr(mm_file, 'path'):
|
||||||
|
path = mm_file.path
|
||||||
|
else:
|
||||||
|
path = mm_file.file.path
|
||||||
|
if is_image(path):
|
||||||
|
return "image"
|
||||||
|
if is_video(path):
|
||||||
|
return "video"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def encode_mm_file(mm_file):
|
||||||
|
if check_mm_type(mm_file) == 'image':
|
||||||
|
return [encode_image(mm_file)]
|
||||||
|
if check_mm_type(mm_file) == 'video':
|
||||||
|
return encode_video(mm_file)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def make_text(text):
|
||||||
|
#return {"type": "text", "pairs": text} # # For remote call
|
||||||
|
return text
|
||||||
|
|
||||||
|
def encode_message(_question):
|
||||||
|
files = _question.files
|
||||||
|
question = _question.text
|
||||||
|
pattern = r"\[mm_media\]\d+\[/mm_media\]"
|
||||||
|
matches = re.split(pattern, question)
|
||||||
|
message = []
|
||||||
|
if len(matches) != len(files) + 1:
|
||||||
|
gr.Warning("Number of Images not match the placeholder in text, please refresh the page to restart!")
|
||||||
|
assert len(matches) == len(files) + 1
|
||||||
|
|
||||||
|
text = matches[0].strip()
|
||||||
|
if text:
|
||||||
|
message.append(make_text(text))
|
||||||
|
for i in range(len(files)):
|
||||||
|
message += encode_mm_file(files[i])
|
||||||
|
text = matches[i + 1].strip()
|
||||||
|
if text:
|
||||||
|
message.append(make_text(text))
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
def check_has_videos(_question):
|
||||||
|
images_cnt = 0
|
||||||
|
videos_cnt = 0
|
||||||
|
for file in _question.files:
|
||||||
|
if check_mm_type(file) == "image":
|
||||||
|
images_cnt += 1
|
||||||
|
else:
|
||||||
|
videos_cnt += 1
|
||||||
|
return images_cnt, videos_cnt
|
||||||
|
|
||||||
|
|
||||||
|
def count_video_frames(_context):
|
||||||
|
num_frames = 0
|
||||||
|
for message in _context:
|
||||||
|
for item in message["content"]:
|
||||||
|
#if item["type"] == "image": # For remote call
|
||||||
|
if isinstance(item, Image.Image):
|
||||||
|
num_frames += 1
|
||||||
|
return num_frames
|
||||||
|
|
||||||
|
|
||||||
|
def respond(_question, _chat_bot, _app_cfg, params_form):
|
||||||
|
_context = _app_cfg['ctx'].copy()
|
||||||
|
_context.append({'role': 'user', 'content': encode_message(_question)})
|
||||||
|
|
||||||
|
images_cnt = _app_cfg['images_cnt']
|
||||||
|
videos_cnt = _app_cfg['videos_cnt']
|
||||||
|
files_cnts = check_has_videos(_question)
|
||||||
|
if files_cnts[1] + videos_cnt > 1 or (files_cnts[1] + videos_cnt == 1 and files_cnts[0] + images_cnt > 0):
|
||||||
|
gr.Warning("Only supports single video file input right now!")
|
||||||
|
return _question, _chat_bot, _app_cfg
|
||||||
|
|
||||||
|
if params_form == 'Beam Search':
|
||||||
|
params = {
|
||||||
|
'sampling': False,
|
||||||
|
'num_beams': 3,
|
||||||
|
'repetition_penalty': 1.2,
|
||||||
|
"max_new_tokens": 2048
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
params = {
|
||||||
|
'sampling': True,
|
||||||
|
'top_p': 0.8,
|
||||||
|
'top_k': 100,
|
||||||
|
'temperature': 0.7,
|
||||||
|
'repetition_penalty': 1.05,
|
||||||
|
"max_new_tokens": 2048
|
||||||
|
}
|
||||||
|
|
||||||
|
if files_cnts[1] + videos_cnt > 0:
|
||||||
|
params["max_inp_length"] = 4352 # 4096+256
|
||||||
|
params["use_image_id"] = False
|
||||||
|
params["max_slice_nums"] = 1 if count_video_frames(_context) > 16 else 2
|
||||||
|
|
||||||
|
code, _answer, _, sts = chat("", _context, None, params)
|
||||||
|
|
||||||
|
images_cnt += files_cnts[0]
|
||||||
|
videos_cnt += files_cnts[1]
|
||||||
|
_context.append({"role": "assistant", "content": [make_text(_answer)]})
|
||||||
|
_chat_bot.append((_question, _answer))
|
||||||
|
if code == 0:
|
||||||
|
_app_cfg['ctx']=_context
|
||||||
|
_app_cfg['sts']=sts
|
||||||
|
_app_cfg['images_cnt'] = images_cnt
|
||||||
|
_app_cfg['videos_cnt'] = videos_cnt
|
||||||
|
|
||||||
|
upload_image_disabled = videos_cnt > 0
|
||||||
|
upload_video_disabled = videos_cnt > 0 or images_cnt > 0
|
||||||
|
return create_multimodal_input(upload_image_disabled, upload_video_disabled), _chat_bot, _app_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def fewshot_add_demonstration(_image, _user_message, _assistant_message, _chat_bot, _app_cfg):
|
||||||
|
ctx = _app_cfg["ctx"]
|
||||||
|
message_item = []
|
||||||
|
if _image is not None:
|
||||||
|
image = Image.open(_image).convert("RGB")
|
||||||
|
ctx.append({"role": "user", "content": [encode_image(image), make_text(_user_message)]})
|
||||||
|
message_item.append({"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]})
|
||||||
|
else:
|
||||||
|
if _user_message:
|
||||||
|
ctx.append({"role": "user", "content": [make_text(_user_message)]})
|
||||||
|
message_item.append({"text": _user_message, "files": []})
|
||||||
|
else:
|
||||||
|
message_item.append(None)
|
||||||
|
if _assistant_message:
|
||||||
|
ctx.append({"role": "assistant", "content": [make_text(_assistant_message)]})
|
||||||
|
message_item.append({"text": _assistant_message, "files": []})
|
||||||
|
else:
|
||||||
|
message_item.append(None)
|
||||||
|
|
||||||
|
_chat_bot.append(message_item)
|
||||||
|
return None, "", "", _chat_bot, _app_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def fewshot_respond(_image, _user_message, _chat_bot, _app_cfg, params_form):
|
||||||
|
user_message_contents = []
|
||||||
|
_context = _app_cfg["ctx"].copy()
|
||||||
|
if _image:
|
||||||
|
image = Image.open(_image).convert("RGB")
|
||||||
|
user_message_contents += [encode_image(image)]
|
||||||
|
if _user_message:
|
||||||
|
user_message_contents += [make_text(_user_message)]
|
||||||
|
if user_message_contents:
|
||||||
|
_context.append({"role": "user", "content": user_message_contents})
|
||||||
|
|
||||||
|
if params_form == 'Beam Search':
|
||||||
|
params = {
|
||||||
|
'sampling': False,
|
||||||
|
'num_beams': 3,
|
||||||
|
'repetition_penalty': 1.2,
|
||||||
|
"max_new_tokens": 2048
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
params = {
|
||||||
|
'sampling': True,
|
||||||
|
'top_p': 0.8,
|
||||||
|
'top_k': 100,
|
||||||
|
'temperature': 0.7,
|
||||||
|
'repetition_penalty': 1.05,
|
||||||
|
"max_new_tokens": 2048
|
||||||
|
}
|
||||||
|
|
||||||
|
code, _answer, _, sts = chat("", _context, None, params)
|
||||||
|
|
||||||
|
_context.append({"role": "assistant", "content": [make_text(_answer)]})
|
||||||
|
|
||||||
|
if _image:
|
||||||
|
_chat_bot.append([
|
||||||
|
{"text": "[mm_media]1[/mm_media]" + _user_message, "files": [_image]},
|
||||||
|
{"text": _answer, "files": []}
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
_chat_bot.append([
|
||||||
|
{"text": _user_message, "files": [_image]},
|
||||||
|
{"text": _answer, "files": []}
|
||||||
|
])
|
||||||
|
if code == 0:
|
||||||
|
_app_cfg['ctx']=_context
|
||||||
|
_app_cfg['sts']=sts
|
||||||
|
return None, '', '', _chat_bot, _app_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def regenerate_button_clicked(_question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg, params_form):
|
||||||
|
if len(_chat_bot) <= 1 or not _chat_bot[-1][1]:
|
||||||
|
gr.Warning('No question for regeneration.')
|
||||||
|
return '', _image, _user_message, _assistant_message, _chat_bot, _app_cfg
|
||||||
|
if _app_cfg["chat_type"] == "Chat":
|
||||||
|
images_cnt = _app_cfg['images_cnt']
|
||||||
|
videos_cnt = _app_cfg['videos_cnt']
|
||||||
|
_question = _chat_bot[-1][0]
|
||||||
|
_chat_bot = _chat_bot[:-1]
|
||||||
|
_app_cfg['ctx'] = _app_cfg['ctx'][:-2]
|
||||||
|
files_cnts = check_has_videos(_question)
|
||||||
|
images_cnt -= files_cnts[0]
|
||||||
|
videos_cnt -= files_cnts[1]
|
||||||
|
_app_cfg['images_cnt'] = images_cnt
|
||||||
|
_app_cfg['videos_cnt'] = videos_cnt
|
||||||
|
upload_image_disabled = videos_cnt > 0
|
||||||
|
upload_video_disabled = videos_cnt > 0 or images_cnt > 0
|
||||||
|
_question, _chat_bot, _app_cfg = respond(_question, _chat_bot, _app_cfg, params_form)
|
||||||
|
return _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
|
||||||
|
else:
|
||||||
|
last_message = _chat_bot[-1][0]
|
||||||
|
last_image = None
|
||||||
|
last_user_message = ''
|
||||||
|
if last_message.text:
|
||||||
|
last_user_message = last_message.text
|
||||||
|
if last_message.files:
|
||||||
|
last_image = last_message.files[0].file.path
|
||||||
|
_chat_bot = _chat_bot[:-1]
|
||||||
|
_app_cfg['ctx'] = _app_cfg['ctx'][:-2]
|
||||||
|
_image, _user_message, _assistant_message, _chat_bot, _app_cfg = fewshot_respond(last_image, last_user_message, _chat_bot, _app_cfg, params_form)
|
||||||
|
return _question, _image, _user_message, _assistant_message, _chat_bot, _app_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def flushed():
|
||||||
|
return gr.update(interactive=True)
|
||||||
|
|
||||||
|
|
||||||
|
def clear(txt_message, chat_bot, app_session):
|
||||||
|
txt_message.files.clear()
|
||||||
|
txt_message.text = ''
|
||||||
|
chat_bot = copy.deepcopy(init_conversation)
|
||||||
|
app_session['sts'] = None
|
||||||
|
app_session['ctx'] = []
|
||||||
|
app_session['images_cnt'] = 0
|
||||||
|
app_session['videos_cnt'] = 0
|
||||||
|
return create_multimodal_input(), chat_bot, app_session, None, '', ''
|
||||||
|
|
||||||
|
|
||||||
|
def select_chat_type(_tab, _app_cfg):
|
||||||
|
_app_cfg["chat_type"] = _tab
|
||||||
|
return _app_cfg
|
||||||
|
|
||||||
|
|
||||||
|
init_conversation = [
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
{
|
||||||
|
# The first message of bot closes the typewriter.
|
||||||
|
"text": "You can talk to me now",
|
||||||
|
"flushing": False
|
||||||
|
}
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
css = """
|
||||||
|
video { height: auto !important; }
|
||||||
|
.example label { font-size: 16px;}
|
||||||
|
"""
|
||||||
|
|
||||||
|
introduction = """
|
||||||
|
|
||||||
|
## Features:
|
||||||
|
1. Chat with single image
|
||||||
|
2. Chat with multiple images
|
||||||
|
3. Chat with video
|
||||||
|
4. In-context few-shot learning
|
||||||
|
|
||||||
|
Click `How to use` tab to see examples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
with gr.Blocks(css=css) as demo:
|
||||||
|
with gr.Tab(model_name):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=1, min_width=300):
|
||||||
|
gr.Markdown(value=introduction)
|
||||||
|
params_form = create_component(form_radio, comp='Radio')
|
||||||
|
regenerate = create_component({'value': 'Regenerate'}, comp='Button')
|
||||||
|
clear_button = create_component({'value': 'Clear History'}, comp='Button')
|
||||||
|
|
||||||
|
with gr.Column(scale=3, min_width=500):
|
||||||
|
app_session = gr.State({'sts':None,'ctx':[], 'images_cnt': 0, 'videos_cnt': 0, 'chat_type': 'Chat'})
|
||||||
|
chat_bot = mgr.Chatbot(label=f"Chat with {model_name}", value=copy.deepcopy(init_conversation), height=600, flushing=False, bubble_full_width=False)
|
||||||
|
|
||||||
|
with gr.Tab("Chat") as chat_tab:
|
||||||
|
txt_message = create_multimodal_input()
|
||||||
|
chat_tab_label = gr.Textbox(value="Chat", interactive=False, visible=False)
|
||||||
|
|
||||||
|
txt_message.submit(
|
||||||
|
respond,
|
||||||
|
[txt_message, chat_bot, app_session, params_form],
|
||||||
|
[txt_message, chat_bot, app_session]
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Tab("Few Shot") as fewshot_tab:
|
||||||
|
fewshot_tab_label = gr.Textbox(value="Few Shot", interactive=False, visible=False)
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
image_input = gr.Image(type="filepath", sources=["upload"])
|
||||||
|
with gr.Column(scale=3):
|
||||||
|
user_message = gr.Textbox(label="User")
|
||||||
|
assistant_message = gr.Textbox(label="Assistant")
|
||||||
|
with gr.Row():
|
||||||
|
add_demonstration_button = gr.Button("Add Example")
|
||||||
|
generate_button = gr.Button(value="Generate", variant="primary")
|
||||||
|
add_demonstration_button.click(
|
||||||
|
fewshot_add_demonstration,
|
||||||
|
[image_input, user_message, assistant_message, chat_bot, app_session],
|
||||||
|
[image_input, user_message, assistant_message, chat_bot, app_session]
|
||||||
|
)
|
||||||
|
generate_button.click(
|
||||||
|
fewshot_respond,
|
||||||
|
[image_input, user_message, chat_bot, app_session, params_form],
|
||||||
|
[image_input, user_message, assistant_message, chat_bot, app_session]
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_tab.select(
|
||||||
|
select_chat_type,
|
||||||
|
[chat_tab_label, app_session],
|
||||||
|
[app_session]
|
||||||
|
)
|
||||||
|
chat_tab.select( # do clear
|
||||||
|
clear,
|
||||||
|
[txt_message, chat_bot, app_session],
|
||||||
|
[txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
|
||||||
|
)
|
||||||
|
fewshot_tab.select(
|
||||||
|
select_chat_type,
|
||||||
|
[fewshot_tab_label, app_session],
|
||||||
|
[app_session]
|
||||||
|
)
|
||||||
|
fewshot_tab.select( # do clear
|
||||||
|
clear,
|
||||||
|
[txt_message, chat_bot, app_session],
|
||||||
|
[txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
|
||||||
|
)
|
||||||
|
chat_bot.flushed(
|
||||||
|
flushed,
|
||||||
|
outputs=[txt_message]
|
||||||
|
)
|
||||||
|
regenerate.click(
|
||||||
|
regenerate_button_clicked,
|
||||||
|
[txt_message, image_input, user_message, assistant_message, chat_bot, app_session, params_form],
|
||||||
|
[txt_message, image_input, user_message, assistant_message, chat_bot, app_session]
|
||||||
|
)
|
||||||
|
clear_button.click(
|
||||||
|
clear,
|
||||||
|
[txt_message, chat_bot, app_session],
|
||||||
|
[txt_message, chat_bot, app_session, image_input, user_message, assistant_message]
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Tab("How to use"):
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
image_example = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/m_bear2.gif", label='1. Chat with single or multiple images', interactive=False, width=400, elem_classes="example")
|
||||||
|
example2 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/video2.gif", label='2. Chat with video', interactive=False, width=400, elem_classes="example")
|
||||||
|
example3 = gr.Image(value="http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/fshot.gif", label='3. Few shot', interactive=False, width=400, elem_classes="example")
|
||||||
|
|
||||||
|
|
||||||
|
# launch
|
||||||
|
demo.launch(share=False, debug=True, show_api=False, server_port=8885, server_name="0.0.0.0")
|
||||||
|
|
||||||