This commit is contained in:
qianyu chen
2024-08-14 18:31:28 +08:00
committed by GitHub
parent 7842ec1228
commit 3b3b9331cb
4 changed files with 35 additions and 16 deletions

View File

@@ -14,6 +14,9 @@ from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from transformers import AutoProcessor, AutoTokenizer
import logging
logger = logging.getLogger(__name__)
llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
@@ -30,6 +33,7 @@ class SupervisedDataset(Dataset):
patch_size=14,
query_nums=64,
batch_vision=False,
max_length=None,
):
super(SupervisedDataset, self).__init__()
self.raw_data = raw_data
@@ -40,13 +44,13 @@ class SupervisedDataset(Dataset):
self.patch_size = patch_size
self.query_nums=query_nums
self.batch_vision = batch_vision
self.max_length = max_length
def __len__(self):
return len(self.raw_data)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
# try:
if 1:
try:
if isinstance(self.raw_data[i]["image"], str):
images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
elif isinstance(self.raw_data[i]["image"], Dict):
@@ -63,6 +67,7 @@ class SupervisedDataset(Dataset):
llm_type=self.llm_type,
patch_size=self.patch_size,
batch_vision=self.batch_vision,
max_length=self.max_length
)
ret = dict(
input_ids=ret["input_ids"],
@@ -73,13 +78,12 @@ class SupervisedDataset(Dataset):
tgt_sizes=ret["tgt_sizes"],
image_bound=ret["image_bound"],
)
# except:
# print(f"data fetch error")
# return self.__getitem__(random.randint(0, len(self)))
except:
logger.error(f"data fetch error")
return self.__getitem__(random.randint(0, len(self)))
return ret
def data_collator(examples, padding_value=0, max_length=2048):
def trim_and_pad(seq, batch_first, padding_value):
return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
@@ -118,7 +122,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
}
def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False, max_length=None):
"""
for single image multi-turn conversation
conversation: [{'role': 'user', 'content': 'Describe this image'},
@@ -139,6 +143,14 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
context = torch.from_numpy(np.hstack(context, dtype=np.int8))
if input_ids.shape[-1] > max_length:
ids =ids[:max_length]
context = context[:max_length]
logger.warning(f"The input length ({input_ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated")
if torch.all(context):
logger.error("No tokens available to compute loss.")
raise Exception("No tokens available to compute loss.")
# build target
target = torch.full_like(ids, -100, dtype=torch.int32)
@@ -164,7 +176,8 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
image_start_tokens += 1
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
if len(image_start_tokens) != len(image_end_tokens):
print("image start token != image end tokens")
logger.error("image start token != image end tokens")
raise Exception("image start token != image end tokens")
if len(image_start_tokens) > 0:
image_bound = torch.hstack(
@@ -294,7 +307,6 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
return input_ids, context, raw_msg
def preprocess(
images_dict,
conversations,
@@ -305,6 +317,7 @@ def preprocess(
llm_type=None,
patch_size=14,
batch_vision=False,
max_length=None,
):
"""
single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
@@ -369,7 +382,7 @@ def preprocess(
conversations[0]["content"] = (
image_placeholder + "\n" + conversation[0]["content"]
)
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
else:
pattern = r'<image_\d+>'
new_conversations = []
@@ -377,16 +390,19 @@ def preprocess(
content = conversation['content']
parts = re.split(f'({pattern})', content)
for i, part in enumerate(parts):
if not part.strip():
continue
if re.match(pattern, part):
if part in image_placeholder_dict:
parts[i] = image_placeholder_dict[part]
else:
print(f'Unreplaced image tag: {part}')
conversation['content'] = ''.join(parts)
raise Exception(f"not found {part} in image dict")
conversation['content'] = '\n'.join(parts)
new_conversations.append(conversation)
conversations = new_conversations
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
if batch_vision:
tgt_sizes = []
reshape_images = []

View File

@@ -108,6 +108,7 @@ def make_supervised_data_module(
patch_size=patch_size,
query_nums=query_nums,
batch_vision=batch_vision,
max_length=max_length,
)
if data_args.eval_data_path:
@@ -121,6 +122,7 @@ def make_supervised_data_module(
patch_size=patch_size,
query_nums=query_nums,
batch_vision=batch_vision,
max_length=max_length,
)
else:
eval_dataset = None
@@ -205,7 +207,6 @@ def train():
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=True
)
tokenizer.model_max_length = training_args.model_max_length
if not training_args.tune_vision:
model.vpm.requires_grad_(False)

View File

@@ -15,7 +15,7 @@ LLM_TYPE="qwen2"
# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
#if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE=llama3
MODEL_MAX_Length=4096 # if use openbmb/MiniCPM-V-2 or openbmb/MiniCPM-Llama3-V-2_5, please set MODEL_MAX_Length=2048
MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \

View File

@@ -53,7 +53,9 @@ If your input consists of a single image, you can use a single placeholder **\<i
</details>
#### Multiple Images Example
For inputs with multiple images, you should use a dictionary where each key represents a unique placeholder (e.g., **\<image_00\>**, **\<image_01\>**), and the corresponding value is the image path. You can then use these placeholders in the conversation to insert the images at specific positions.
For inputs containing multiple images, utilize a dictionary where each key represents a unique placeholder (e.g., \textbf{\textbackslash image\_00}, \textbf{\textbackslash image\_01}) with the corresponding image path as its value. These placeholders can then be used within the conversation to seamlessly insert images at specific positions.
Additionally, to optimize resource management, especially when dealing with large batches of images during training or inference, consider reducing \texttt{max\_slice\_nums}. If you are performing multi-image supervised fine-tuning (SFT), it's recommended to set \texttt{MODEL\_MAX\_LENGTH=4096} in your script for better performance.
<details>
<summary>