From 3b3b9331cbab8f71657623b609eb7cb9228d782d Mon Sep 17 00:00:00 2001
From: qianyu chen <38046403+qyc-98@users.noreply.github.com>
Date: Wed, 14 Aug 2024 18:31:28 +0800
Subject: [PATCH] update

---
 finetune/dataset.py       | 42 +++++++++++++++++++++++++++------------
 finetune/finetune.py      |  3 ++-
 finetune/finetune_lora.sh |  2 +-
 finetune/readme.md        |  4 +++-
 4 files changed, 35 insertions(+), 16 deletions(-)

diff --git a/finetune/dataset.py b/finetune/dataset.py
index 232a4c2..5345281 100644
--- a/finetune/dataset.py
+++ b/finetune/dataset.py
@@ -14,6 +14,9 @@ from PIL import Image
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import Dataset
 from transformers import AutoProcessor, AutoTokenizer
+import logging
+
+logger = logging.getLogger(__name__)
 
 llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
 
@@ -30,6 +33,7 @@ class SupervisedDataset(Dataset):
         patch_size=14,
         query_nums=64,
         batch_vision=False,
+        max_length=None,
     ):
         super(SupervisedDataset, self).__init__()
         self.raw_data = raw_data
@@ -40,13 +44,13 @@ class SupervisedDataset(Dataset):
         self.patch_size = patch_size
         self.query_nums=query_nums
         self.batch_vision = batch_vision
+        self.max_length = max_length
 
     def __len__(self):
         return len(self.raw_data)
 
     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
-        # try:
-        if 1:
+        try:
             if isinstance(self.raw_data[i]["image"], str):
                 images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
             elif isinstance(self.raw_data[i]["image"], Dict):
@@ -63,6 +67,7 @@ class SupervisedDataset(Dataset):
                 llm_type=self.llm_type,
                 patch_size=self.patch_size,
                 batch_vision=self.batch_vision,
+                max_length=self.max_length,
             )
             ret = dict(
                 input_ids=ret["input_ids"],
@@ -73,13 +78,12 @@ class SupervisedDataset(Dataset):
                 tgt_sizes=ret["tgt_sizes"],
                 image_bound=ret["image_bound"],
             )
-            # except:
-            #     print(f"data fetch error")
-            #     return self.__getitem__(random.randint(0, len(self)))
+        except Exception as e:
+            logger.error(f"data fetch error: {e}")
+            return self.__getitem__(random.randint(0, len(self) - 1))
         return ret
 
 
-
 def data_collator(examples, padding_value=0, max_length=2048):
     def trim_and_pad(seq, batch_first, padding_value):
         return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
@@ -118,7 +122,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
     }
 
 
-def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
+def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False, max_length=None):
     """
     for single image multi-turn conversation
     conversation: [{'role': 'user', 'content': 'Describe this image'},
@@ -139,6 +143,14 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
 
     ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
     context = torch.from_numpy(np.hstack(context, dtype=np.int8))
+    if max_length is not None and ids.shape[-1] > max_length:
+        logger.warning(f"The input length ({ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated")
+        ids = ids[:max_length]
+        context = context[:max_length]
+
+    if torch.all(context):
+        logger.error("No tokens available to compute loss.")
+        raise Exception("No tokens available to compute loss.")
 
     # build target
     target = torch.full_like(ids, -100, dtype=torch.int32)
@@ -164,7 +176,8 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
     image_start_tokens += 1
     image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
     if len(image_start_tokens) != len(image_end_tokens):
-        print("image start token != image end tokens")
+        logger.error("image start token != image end tokens")
+        raise Exception("image start token != image end tokens")
     if len(image_start_tokens) > 0:
         image_bound = torch.hstack(
@@ -294,7 +307,6 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
 
     return input_ids, context, raw_msg
 
-
 def preprocess(
     images_dict,
     conversations,
@@ -305,6 +317,7 @@ def preprocess(
     llm_type=None,
     patch_size=14,
     batch_vision=False,
+    max_length=None,
 ):
     """
     single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
@@ -369,7 +382,7 @@ def preprocess(
         conversations[0]["content"] = (
             image_placeholder + "\n" + conversation[0]["content"]
         )
-        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
+        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
     else:
         pattern = r'<image_\d+>'
         new_conversations = []
@@ -377,16 +390,19 @@ def preprocess(
             content = conversation['content']
             parts = re.split(f'({pattern})', content)
             for i, part in enumerate(parts):
+                if not part.strip():
+                    continue
                 if re.match(pattern, part):
                     if part in image_placeholder_dict:
                         parts[i] = image_placeholder_dict[part]
                     else:
-                        print(f'Unreplaced image tag: {part}')
-            conversation['content'] = ''.join(parts)
+                        raise Exception(f"not found {part} in image dict")
+            conversation['content'] = '\n'.join(parts)
             new_conversations.append(conversation)
         conversations = new_conversations
 
-        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
+        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
+
     if batch_vision:
         tgt_sizes = []
         reshape_images = []
diff --git a/finetune/finetune.py b/finetune/finetune.py
index d9e0403..0c596d7 100644
--- a/finetune/finetune.py
+++ b/finetune/finetune.py
@@ -108,6 +108,7 @@ def make_supervised_data_module(
         patch_size=patch_size,
         query_nums=query_nums,
         batch_vision=batch_vision,
+        max_length=max_length,
     )
 
     if data_args.eval_data_path:
@@ -121,6 +122,7 @@ def make_supervised_data_module(
             patch_size=patch_size,
             query_nums=query_nums,
             batch_vision=batch_vision,
+            max_length=max_length,
         )
     else:
         eval_dataset = None
@@ -205,7 +207,6 @@ def train():
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=True
     )
-    tokenizer.model_max_length = training_args.model_max_length
 
     if not training_args.tune_vision:
         model.vpm.requires_grad_(False)
diff --git a/finetune/finetune_lora.sh b/finetune/finetune_lora.sh
index a876d9a..19437b1 100644
--- a/finetune/finetune_lora.sh
+++ b/finetune/finetune_lora.sh
@@ -15,7 +15,7 @@ LLM_TYPE="qwen2"
 # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
 #if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE=llama3
 
-MODEL_MAX_Length=4096 # if use openbmb/MiniCPM-V-2 or openbmb/MiniCPM-Llama3-V-2_5, please set MODEL_MAX_Length=2048
+MODEL_MAX_Length=2048 # if conducting multi-image SFT, please set MODEL_MAX_Length=4096
 
 DISTRIBUTED_ARGS="
     --nproc_per_node $GPUS_PER_NODE \
diff --git a/finetune/readme.md b/finetune/readme.md
index aa3744a..47deb62 100644
--- a/finetune/readme.md
+++ b/finetune/readme.md
@@ -53,7 +53,9 @@ If your input consists of a single image, you can use a single placeholder **\
 
 #### Multiple Images Example
 
-For inputs with multiple images, you should use a dictionary where each key represents a unique placeholder (e.g., **\<image_00\>**, **\<image_01\>**), and the corresponding value is the image path. You can then use these placeholders in the conversation to insert the images at specific positions.
+For inputs containing multiple images, utilize a dictionary where each key represents a unique placeholder (e.g., **\<image_00\>**, **\<image_01\>**) with the corresponding image path as its value. These placeholders can then be used within the conversation to seamlessly insert images at specific positions.
+
+Additionally, to keep memory usage manageable when dealing with large batches of images during training or inference, consider reducing `max_slice_nums`. If you are performing multi-image supervised fine-tuning (SFT), it is recommended to set `MODEL_MAX_Length=4096` in your training script for better performance.
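For reference, a multi-image training sample following the placeholder convention above might look like the sketch below. This is an illustrative assumption rather than part of the patch: the `"image"` dictionary maps each `<image_xx>` placeholder to an image path (the `Dict` branch of `SupervisedDataset.__getitem__` and the image-tag pattern in `preprocess` handle exactly these keys), the dialogue is assumed to live under a `conversations` list of `role`/`content` turns as in the single-image format, and all paths and texts are made up.

```python
# Illustrative multi-image sample (assumed structure; paths and contents are placeholders).
sample = {
    "image": {
        "<image_00>": "images/cat_front.jpg",  # each key is a placeholder referenced in the content
        "<image_01>": "images/cat_side.jpg",
    },
    "conversations": [
        {
            "role": "user",
            "content": "<image_00>\n<image_01>\nHow do these two views of the cat differ?",
        },
        {
            "role": "assistant",
            "content": "The first image shows the cat from the front, while the second shows it from the side.",
        },
    ],
}
```

With samples like this, setting `MODEL_MAX_Length=4096` in `finetune_lora.sh` gives the longer multi-image sequences room before the truncation warning added to `conversation_to_ids` is triggered.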