commit 3b3b9331cb (parent 7842ec1228)
Author: qianyu chen
Date: 2024-08-14 18:31:28 +08:00 (committed by GitHub)

4 changed files with 35 additions and 16 deletions


@@ -14,6 +14,9 @@ from PIL import Image
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import Dataset
 from transformers import AutoProcessor, AutoTokenizer
+import logging
+logger = logging.getLogger(__name__)
+
 llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
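The llama3_chat_template string above is a Jinja2 template. As a quick illustration of what it renders — a minimal sketch using jinja2 directly rather than the tokenizer, with an assumed bos_token value:

import jinja2

# assumes the llama3_chat_template string defined above is in scope
template = jinja2.Template(llama3_chat_template)
messages = [
    {"role": "user", "content": "Describe this image"},
    {"role": "assistant", "content": "A dog lying on the grass."},
]
# "<|begin_of_text|>" is an assumption; the tokenizer supplies the real bos_token
print(template.render(messages=messages, bos_token="<|begin_of_text|>"))
# <|begin_of_text|><|start_header_id|>user<|end_header_id|>
#
# Describe this image<|eot_id|><|start_header_id|>assistant<|end_header_id|>
#
# A dog lying on the grass.<|eot_id|>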
@@ -30,6 +33,7 @@ class SupervisedDataset(Dataset):
         patch_size=14,
         query_nums=64,
         batch_vision=False,
+        max_length=None,
     ):
         super(SupervisedDataset, self).__init__()
         self.raw_data = raw_data
@@ -40,13 +44,13 @@ class SupervisedDataset(Dataset):
         self.patch_size = patch_size
         self.query_nums=query_nums
         self.batch_vision = batch_vision
+        self.max_length = max_length

     def __len__(self):
         return len(self.raw_data)

     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
-        # try:
-        if 1:
+        try:
             if isinstance(self.raw_data[i]["image"], str):
                 images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
             elif isinstance(self.raw_data[i]["image"], Dict):
@@ -63,6 +67,7 @@ class SupervisedDataset(Dataset):
                 llm_type=self.llm_type,
                 patch_size=self.patch_size,
                 batch_vision=self.batch_vision,
+                max_length=self.max_length
             )
             ret = dict(
                 input_ids=ret["input_ids"],
@@ -73,13 +78,12 @@ class SupervisedDataset(Dataset):
                 tgt_sizes=ret["tgt_sizes"],
                 image_bound=ret["image_bound"],
             )
-        # except:
-        #     print(f"data fetch error")
-        #     return self.__getitem__(random.randint(0, len(self)))
+        except:
+            logger.error(f"data fetch error")
+            return self.__getitem__(random.randint(0, len(self)))
         return ret

 def data_collator(examples, padding_value=0, max_length=2048):
     def trim_and_pad(seq, batch_first, padding_value):
         return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
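The except branch added above turns a bad sample into a logged retry on a random other index instead of a crash mid-epoch. A self-contained sketch of the pattern (class and helper names here are illustrative, not the repo's):

import logging
import random
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)

class RetryOnErrorDataset(Dataset):  # hypothetical name, for illustration only
    def __init__(self, raw_data):
        self.raw_data = raw_data

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i):
        try:
            return self._build_sample(i)
        except Exception:
            logger.error("data fetch error")
            # note: random.randint is inclusive on both ends, so len(self) - 1
            # avoids the off-by-one IndexError that len(self) would permit
            return self.__getitem__(random.randint(0, len(self) - 1))

    def _build_sample(self, i):  # stands in for the image loading + preprocess() call
        return self.raw_data[i]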
@@ -118,7 +122,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
     }
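For reference, trim_and_pad (defined in the previous hunk) first truncates every sequence to max_length and only then right-pads the batch. A runnable sketch of that behavior (the original also accepts a batch_first argument it never uses, which this sketch drops):

import torch
from torch.nn.utils.rnn import pad_sequence

def trim_and_pad(seq, max_length=2048, padding_value=0):
    # truncate each sequence, then pad the batch to the longest survivor
    return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)

batch = [torch.tensor([1, 2, 3, 4, 5]), torch.tensor([6, 7])]
print(trim_and_pad(batch, max_length=4))
# tensor([[1, 2, 3, 4],
#         [6, 7, 0, 0]])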
-def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
+def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False, max_length=None):
     """
     for single image multi-turn conversation
     conversation: [{'role': 'user', 'content': 'Describe this image'},
@@ -139,6 +143,14 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
     ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
     context = torch.from_numpy(np.hstack(context, dtype=np.int8))
+    if input_ids.shape[-1] > max_length:
+        ids = ids[:max_length]
+        context = context[:max_length]
+        logger.warning(f"The input length ({input_ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated")
+
+    if torch.all(context):
+        logger.error("No tokens available to compute loss.")
+        raise Exception("No tokens available to compute loss.")

     # build target
     target = torch.full_like(ids, -100, dtype=torch.int32)
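A minimal sketch of the two guards just added, assuming ids holds the token ids and context is 1 at prompt positions (masked from the loss) and 0 at answer positions. Truncation can cut away every answer token, which is exactly the case the torch.all(context) check rejects:

import torch

max_length = 8
ids = torch.arange(12, dtype=torch.int32)
context = torch.ones(12, dtype=torch.int8)
context[9:] = 0  # only the final three tokens would be trained on

if ids.shape[-1] > max_length:
    ids = ids[:max_length]
    context = context[:max_length]  # the three answer tokens are gone now

if torch.all(context):
    # every surviving position is context, so the loss would be empty
    raise Exception("No tokens available to compute loss.")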
@@ -164,7 +176,8 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
     image_start_tokens += 1
     image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
     if len(image_start_tokens) != len(image_end_tokens):
-        print("image start token != image end tokens")
+        logger.error("image start token != image end tokens")
+        raise Exception("image start token != image end tokens")
     if len(image_start_tokens) > 0:
         image_bound = torch.hstack(
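A mismatch between image start and end markers is now a hard failure rather than a print, before torch.hstack pairs their positions into image_bound. A sketch of the pairing logic with assumed special-token ids:

import torch

im_start_id, im_end_id = 101, 102  # assumed ids for the <im_start>/<im_end> tokens
ids = torch.tensor([5, 101, 7, 7, 102, 9, 101, 7, 102], dtype=torch.int32)

image_start_tokens = torch.where(ids == im_start_id)[0]
image_start_tokens += 1  # first position after <im_start>
image_end_tokens = torch.where(ids == im_end_id)[0]
if len(image_start_tokens) != len(image_end_tokens):
    raise Exception("image start token != image end tokens")
image_bound = torch.hstack(
    [image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
)
print(image_bound)  # tensor([[2, 4], [7, 8]]): [start, end) span of each image slot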
@@ -294,7 +307,6 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
     return input_ids, context, raw_msg

 def preprocess(
     images_dict,
     conversations,
@@ -305,6 +317,7 @@ def preprocess(
     llm_type=None,
     patch_size=14,
     batch_vision=False,
+    max_length=None,
 ):
     """
     single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
@@ -369,7 +382,7 @@ def preprocess(
         conversations[0]["content"] = (
             image_placeholder + "\n" + conversation[0]["content"]
         )
-        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
+        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
     else:
         pattern = r'<image_\d+>'
         new_conversations = []
@@ -377,16 +390,19 @@
             content = conversation['content']
             parts = re.split(f'({pattern})', content)
             for i, part in enumerate(parts):
+                if not part.strip():
+                    continue
                 if re.match(pattern, part):
                     if part in image_placeholder_dict:
                         parts[i] = image_placeholder_dict[part]
                     else:
-                        print(f'Unreplaced image tag: {part}')
-            conversation['content'] = ''.join(parts)
+                        raise Exception(f"not found {part} in image dict")
+            conversation['content'] = '\n'.join(parts)
             new_conversations.append(conversation)
         conversations = new_conversations

-        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
+        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)

     if batch_vision:
         tgt_sizes = []
         reshape_images = []
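Putting the placeholder pass above together: the content is split on <image_N> tags, each known tag is swapped for its rendered placeholder, and an unknown tag now raises instead of being printed and left in place. A self-contained sketch (the placeholder text is an assumption):

import re

pattern = r'<image_\d+>'
image_placeholder_dict = {'<image_00>': '(<image>./</image>)'}  # assumed rendering

content = 'Compare <image_00> with the text below.'
parts = re.split(f'({pattern})', content)
for i, part in enumerate(parts):
    if not part.strip():
        continue  # skip whitespace-only fragments, as the new code does
    if re.match(pattern, part):
        if part in image_placeholder_dict:
            parts[i] = image_placeholder_dict[part]
        else:
            raise Exception(f"not found {part} in image dict")
content = '\n'.join(parts)
print(content)
# Compare
# (<image>./</image>)
#  with the text below.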