mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-04 17:59:18 +08:00
update
This commit is contained in:
@@ -14,6 +14,9 @@ from PIL import Image
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
|
||||
|
||||
@@ -30,6 +33,7 @@ class SupervisedDataset(Dataset):
|
||||
patch_size=14,
|
||||
query_nums=64,
|
||||
batch_vision=False,
|
||||
max_length=None,
|
||||
):
|
||||
super(SupervisedDataset, self).__init__()
|
||||
self.raw_data = raw_data
|
||||
@@ -40,13 +44,13 @@ class SupervisedDataset(Dataset):
|
||||
self.patch_size = patch_size
|
||||
self.query_nums=query_nums
|
||||
self.batch_vision = batch_vision
|
||||
self.max_length = max_length
|
||||
|
||||
def __len__(self):
|
||||
return len(self.raw_data)
|
||||
|
||||
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
||||
# try:
|
||||
if 1:
|
||||
try:
|
||||
if isinstance(self.raw_data[i]["image"], str):
|
||||
images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
|
||||
elif isinstance(self.raw_data[i]["image"], Dict):
|
||||
@@ -63,6 +67,7 @@ class SupervisedDataset(Dataset):
|
||||
llm_type=self.llm_type,
|
||||
patch_size=self.patch_size,
|
||||
batch_vision=self.batch_vision,
|
||||
max_length=self.max_length
|
||||
)
|
||||
ret = dict(
|
||||
input_ids=ret["input_ids"],
|
||||
@@ -73,13 +78,12 @@ class SupervisedDataset(Dataset):
|
||||
tgt_sizes=ret["tgt_sizes"],
|
||||
image_bound=ret["image_bound"],
|
||||
)
|
||||
# except:
|
||||
# print(f"data fetch error")
|
||||
# return self.__getitem__(random.randint(0, len(self)))
|
||||
except:
|
||||
logger.error(f"data fetch error")
|
||||
return self.__getitem__(random.randint(0, len(self)))
|
||||
return ret
|
||||
|
||||
|
||||
|
||||
def data_collator(examples, padding_value=0, max_length=2048):
|
||||
def trim_and_pad(seq, batch_first, padding_value):
|
||||
return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
|
||||
@@ -118,7 +122,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
|
||||
}
|
||||
|
||||
|
||||
def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
|
||||
def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False, max_length=None):
|
||||
"""
|
||||
for single image multi-turn conversation
|
||||
conversation: [{'role': 'user', 'content': 'Describe this image'},
|
||||
@@ -139,6 +143,14 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
|
||||
|
||||
ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
|
||||
context = torch.from_numpy(np.hstack(context, dtype=np.int8))
|
||||
if input_ids.shape[-1] > max_length:
|
||||
ids =ids[:max_length]
|
||||
context = context[:max_length]
|
||||
logger.warning(f"The input length ({input_ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated")
|
||||
|
||||
if torch.all(context):
|
||||
logger.error("No tokens available to compute loss.")
|
||||
raise Exception("No tokens available to compute loss.")
|
||||
|
||||
# build target
|
||||
target = torch.full_like(ids, -100, dtype=torch.int32)
|
||||
@@ -164,7 +176,8 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
|
||||
image_start_tokens += 1
|
||||
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
|
||||
if len(image_start_tokens) != len(image_end_tokens):
|
||||
print("image start token != image end tokens")
|
||||
logger.error("image start token != image end tokens")
|
||||
raise Exception("image start token != image end tokens")
|
||||
|
||||
if len(image_start_tokens) > 0:
|
||||
image_bound = torch.hstack(
|
||||
@@ -294,7 +307,6 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
|
||||
return input_ids, context, raw_msg
|
||||
|
||||
|
||||
|
||||
def preprocess(
|
||||
images_dict,
|
||||
conversations,
|
||||
@@ -305,6 +317,7 @@ def preprocess(
|
||||
llm_type=None,
|
||||
patch_size=14,
|
||||
batch_vision=False,
|
||||
max_length=None,
|
||||
):
|
||||
"""
|
||||
single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
|
||||
@@ -369,7 +382,7 @@ def preprocess(
|
||||
conversations[0]["content"] = (
|
||||
image_placeholder + "\n" + conversation[0]["content"]
|
||||
)
|
||||
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
|
||||
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
|
||||
else:
|
||||
pattern = r'<image_\d+>'
|
||||
new_conversations = []
|
||||
@@ -377,16 +390,19 @@ def preprocess(
|
||||
content = conversation['content']
|
||||
parts = re.split(f'({pattern})', content)
|
||||
for i, part in enumerate(parts):
|
||||
if not part.strip():
|
||||
continue
|
||||
if re.match(pattern, part):
|
||||
if part in image_placeholder_dict:
|
||||
parts[i] = image_placeholder_dict[part]
|
||||
else:
|
||||
print(f'Unreplaced image tag: {part}')
|
||||
conversation['content'] = ''.join(parts)
|
||||
raise Exception(f"not found {part} in image dict")
|
||||
conversation['content'] = '\n'.join(parts)
|
||||
new_conversations.append(conversation)
|
||||
conversations = new_conversations
|
||||
|
||||
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
|
||||
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
|
||||
|
||||
if batch_vision:
|
||||
tgt_sizes = []
|
||||
reshape_images = []
|
||||
|
||||
@@ -108,6 +108,7 @@ def make_supervised_data_module(
|
||||
patch_size=patch_size,
|
||||
query_nums=query_nums,
|
||||
batch_vision=batch_vision,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
if data_args.eval_data_path:
|
||||
@@ -121,6 +122,7 @@ def make_supervised_data_module(
|
||||
patch_size=patch_size,
|
||||
query_nums=query_nums,
|
||||
batch_vision=batch_vision,
|
||||
max_length=max_length,
|
||||
)
|
||||
else:
|
||||
eval_dataset = None
|
||||
@@ -205,7 +207,6 @@ def train():
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=True
|
||||
)
|
||||
tokenizer.model_max_length = training_args.model_max_length
|
||||
|
||||
if not training_args.tune_vision:
|
||||
model.vpm.requires_grad_(False)
|
||||
|
||||
@@ -15,7 +15,7 @@ LLM_TYPE="qwen2"
|
||||
# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
|
||||
#if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE=llama3
|
||||
|
||||
MODEL_MAX_Length=4096 # if use openbmb/MiniCPM-V-2 or openbmb/MiniCPM-Llama3-V-2_5, please set MODEL_MAX_Length=2048
|
||||
MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096
|
||||
|
||||
DISTRIBUTED_ARGS="
|
||||
--nproc_per_node $GPUS_PER_NODE \
|
||||
|
||||
@@ -53,7 +53,9 @@ If your input consists of a single image, you can use a single placeholder **\<i
|
||||
</details>
|
||||
|
||||
#### Multiple Images Example
|
||||
For inputs with multiple images, you should use a dictionary where each key represents a unique placeholder (e.g., **\<image_00\>**, **\<image_01\>**), and the corresponding value is the image path. You can then use these placeholders in the conversation to insert the images at specific positions.
|
||||
For inputs containing multiple images, utilize a dictionary where each key represents a unique placeholder (e.g., \textbf{\textbackslash image\_00}, \textbf{\textbackslash image\_01}) with the corresponding image path as its value. These placeholders can then be used within the conversation to seamlessly insert images at specific positions.
|
||||
|
||||
Additionally, to optimize resource management, especially when dealing with large batches of images during training or inference, consider reducing \texttt{max\_slice\_nums}. If you are performing multi-image supervised fine-tuning (SFT), it's recommended to set \texttt{MODEL\_MAX\_LENGTH=4096} in your script for better performance.
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
|
||||
Reference in New Issue
Block a user