Mirror of https://github.com/OpenBMB/MiniCPM-V.git
update
@@ -14,6 +14,9 @@ from PIL import Image
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import Dataset
 from transformers import AutoProcessor, AutoTokenizer
+import logging
+
+logger = logging.getLogger(__name__)
 
 llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
 
@@ -30,6 +33,7 @@ class SupervisedDataset(Dataset):
         patch_size=14,
         query_nums=64,
         batch_vision=False,
+        max_length=None,
     ):
         super(SupervisedDataset, self).__init__()
         self.raw_data = raw_data
@@ -40,13 +44,13 @@ class SupervisedDataset(Dataset):
         self.patch_size = patch_size
         self.query_nums=query_nums
         self.batch_vision = batch_vision
+        self.max_length = max_length
 
     def __len__(self):
         return len(self.raw_data)
 
     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
-        # try:
-        if 1:
+        try:
             if isinstance(self.raw_data[i]["image"], str):
                 images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
             elif isinstance(self.raw_data[i]["image"], Dict):
@@ -63,6 +67,7 @@ class SupervisedDataset(Dataset):
                 llm_type=self.llm_type,
                 patch_size=self.patch_size,
                 batch_vision=self.batch_vision,
+                max_length=self.max_length
             )
             ret = dict(
                 input_ids=ret["input_ids"],
@@ -73,13 +78,12 @@ class SupervisedDataset(Dataset):
                 tgt_sizes=ret["tgt_sizes"],
                 image_bound=ret["image_bound"],
             )
-        # except:
-            # print(f"data fetch error")
-            # return self.__getitem__(random.randint(0, len(self)))
+        except:
+            logger.error(f"data fetch error")
+            return self.__getitem__(random.randint(0, len(self)))
         return ret
 
 
-
 def data_collator(examples, padding_value=0, max_length=2048):
     def trim_and_pad(seq, batch_first, padding_value):
         return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
@@ -118,7 +122,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
     }
 
 
-def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
+def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False, max_length=None):
     """
     for single image multi-turn conversation
     conversation: [{'role': 'user', 'content': 'Describe this image'},
@@ -139,6 +143,14 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
 
     ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
     context = torch.from_numpy(np.hstack(context, dtype=np.int8))
+    if input_ids.shape[-1] > max_length:
+        ids = ids[:max_length]
+        context = context[:max_length]
+        logger.warning(f"The input length ({input_ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated")
+
+    if torch.all(context):
+        logger.error("No tokens available to compute loss.")
+        raise Exception("No tokens available to compute loss.")
 
     # build target
     target = torch.full_like(ids, -100, dtype=torch.int32)
@@ -164,7 +176,8 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
     image_start_tokens += 1
     image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
     if len(image_start_tokens) != len(image_end_tokens):
-        print("image start token != image end tokens")
+        logger.error("image start token != image end tokens")
+        raise Exception("image start token != image end tokens")
 
     if len(image_start_tokens) > 0:
         image_bound = torch.hstack(
@@ -294,7 +307,6 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
     return input_ids, context, raw_msg
 
 
-
 def preprocess(
     images_dict,
     conversations,
@@ -305,6 +317,7 @@ def preprocess(
     llm_type=None,
     patch_size=14,
     batch_vision=False,
+    max_length=None,
 ):
     """
     single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
@@ -369,7 +382,7 @@ def preprocess(
         conversations[0]["content"] = (
             image_placeholder + "\n" + conversation[0]["content"]
         )
-        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
+        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
     else:
         pattern = r'<image_\d+>'
         new_conversations = []
@@ -377,16 +390,19 @@ def preprocess(
             content = conversation['content']
             parts = re.split(f'({pattern})', content)
             for i, part in enumerate(parts):
+                if not part.strip():
+                    continue
                 if re.match(pattern, part):
                     if part in image_placeholder_dict:
                         parts[i] = image_placeholder_dict[part]
                     else:
-                        print(f'Unreplaced image tag: {part}')
-            conversation['content'] = ''.join(parts)
+                        raise Exception(f"not found {part} in image dict")
+            conversation['content'] = '\n'.join(parts)
             new_conversations.append(conversation)
         conversations = new_conversations
-        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
+
+        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
 
     if batch_vision:
         tgt_sizes = []
         reshape_images = []
@@ -108,6 +108,7 @@ def make_supervised_data_module(
         patch_size=patch_size,
         query_nums=query_nums,
         batch_vision=batch_vision,
+        max_length=max_length,
     )
 
     if data_args.eval_data_path:
@@ -121,6 +122,7 @@ def make_supervised_data_module(
             patch_size=patch_size,
             query_nums=query_nums,
             batch_vision=batch_vision,
+            max_length=max_length,
         )
     else:
         eval_dataset = None
@@ -205,7 +207,6 @@ def train():
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=True
     )
-    tokenizer.model_max_length = training_args.model_max_length
 
     if not training_args.tune_vision:
         model.vpm.requires_grad_(False)
@@ -15,7 +15,7 @@ LLM_TYPE="qwen2"
 # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
 #if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE=llama3
 
-MODEL_MAX_Length=4096 # if use openbmb/MiniCPM-V-2 or openbmb/MiniCPM-Llama3-V-2_5, please set MODEL_MAX_Length=2048
+MODEL_MAX_Length=2048 # if conducting multi-image SFT, please set MODEL_MAX_Length=4096
 
 DISTRIBUTED_ARGS="
     --nproc_per_node $GPUS_PER_NODE \
@@ -53,7 +53,9 @@ If your input consists of a single image, you can use a single placeholder **\<i
 </details>
 
 #### Multiple Images Example
-For inputs with multiple images, you should use a dictionary where each key represents a unique placeholder (e.g., **\<image_00\>**, **\<image_01\>**), and the corresponding value is the image path. You can then use these placeholders in the conversation to insert the images at specific positions.
+For inputs containing multiple images, use a dictionary where each key is a unique placeholder (e.g., **\<image_00\>**, **\<image_01\>**) and the corresponding value is the image path. These placeholders can then be used within the conversation to insert images at specific positions.
 
+Additionally, to optimize resource management when dealing with large batches of images during training or inference, consider reducing `max_slice_nums`. For multi-image supervised fine-tuning (SFT), it is recommended to set `MODEL_MAX_Length=4096` in your training script for better performance.
+
 <details>
 <summary>
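To make the multi-image format described in the README hunk above concrete, here is a minimal sketch of a single training sample consistent with the `<image_xx>` placeholder scheme and with the `image` / `conversations` / `role` / `content` fields read by `SupervisedDataset` and `preprocess` in the diff; the `id` field, the file paths, the conversation text, and the output filename are illustrative assumptions, not part of the commit.

```python
import json

# Hypothetical multi-image SFT sample (illustrative only). The keys of the
# "image" dict match the <image_\d+> placeholders used inside the
# conversation content, as described in the README excerpt above.
sample = {
    "id": "0",  # assumed field, for bookkeeping only
    "image": {
        "<image_00>": "path/to/image_00.jpg",
        "<image_01>": "path/to/image_01.jpg",
    },
    "conversations": [
        {
            "role": "user",
            "content": "<image_00>\n<image_01>\nWhat changed between these two screenshots?",
        },
        {
            "role": "assistant",
            "content": "The second screenshot shows the settings dialog after the changes were saved.",
        },
    ],
}

# The finetuning data file is assumed to be a JSON list of such samples.
with open("train_multi_image.json", "w") as f:
    json.dump([sample], f, ensure_ascii=False, indent=2)
```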