update
@@ -14,6 +14,9 @@ from PIL import Image
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import Dataset
 from transformers import AutoProcessor, AutoTokenizer
+import logging
+
+logger = logging.getLogger(__name__)
 
 llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
 
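The new module-level logger only emits records if the training entry point configures logging; this hunk itself never calls basicConfig. A minimal sketch of what a caller might add so the warnings and errors introduced by this commit actually surface (the level and format string here are illustrative assumptions, not part of the change):

    import logging

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )
    logger = logging.getLogger("dataset")
    logger.warning("truncation and data-fetch problems will now show up in the training log")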
@@ -30,6 +33,7 @@ class SupervisedDataset(Dataset):
         patch_size=14,
         query_nums=64,
         batch_vision=False,
+        max_length=None,
     ):
         super(SupervisedDataset, self).__init__()
         self.raw_data = raw_data
@@ -40,13 +44,13 @@ class SupervisedDataset(Dataset):
         self.patch_size = patch_size
         self.query_nums=query_nums
         self.batch_vision = batch_vision
+        self.max_length = max_length
 
     def __len__(self):
         return len(self.raw_data)
 
     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
-        # try:
-        if 1:
+        try:
             if isinstance(self.raw_data[i]["image"], str):
                 images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
             elif isinstance(self.raw_data[i]["image"], Dict):
@@ -63,6 +67,7 @@ class SupervisedDataset(Dataset):
                 llm_type=self.llm_type,
                 patch_size=self.patch_size,
                 batch_vision=self.batch_vision,
+                max_length=self.max_length
             )
             ret = dict(
                 input_ids=ret["input_ids"],
@@ -73,13 +78,12 @@ class SupervisedDataset(Dataset):
                 tgt_sizes=ret["tgt_sizes"],
                 image_bound=ret["image_bound"],
             )
-        # except:
-        #     print(f"data fetch error")
-        #     return self.__getitem__(random.randint(0, len(self)))
+        except:
+            logger.error(f"data fetch error")
+            return self.__getitem__(random.randint(0, len(self)))
         return ret
 
 
-
 def data_collator(examples, padding_value=0, max_length=2048):
     def trim_and_pad(seq, batch_first, padding_value):
         return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
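The now-active except block logs the failure and retries with a random index instead of crashing the DataLoader worker. A self-contained sketch of the same resample-on-error pattern, with a hypothetical load_sample helper standing in for the real image decoding and preprocessing; note that random.randint is inclusive at both ends, so the sketch uses len(self) - 1 to stay in range:

    import logging
    import random

    logger = logging.getLogger(__name__)

    class ResilientDataset:
        def __init__(self, raw_data):
            self.raw_data = raw_data

        def __len__(self):
            return len(self.raw_data)

        def load_sample(self, i):
            # stand-in for image decoding + tokenization; may raise on corrupt data
            return {"input_ids": self.raw_data[i]}

        def __getitem__(self, i):
            try:
                return self.load_sample(i)
            except Exception:
                # log the failure and fall back to another randomly chosen sample
                logger.exception("data fetch error at index %d", i)
                return self.__getitem__(random.randint(0, len(self) - 1))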
@@ -118,7 +122,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
     }
 
 
-def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
+def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False, max_length=None):
     """
     for single image multi-turn conversation
     conversation: [{'role': 'user', 'content': 'Describe this image'},
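For context on the collator whose closing brace appears above: trim_and_pad (visible at the end of the previous hunk) first truncates every sequence to max_length and then right-pads the batch to a common length with pad_sequence. A small sketch of that behaviour on dummy tensors, where the lengths and padding value are made up:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    def trim_and_pad(seq, padding_value=0, max_length=8):
        # truncate each 1-D tensor to max_length, then pad the batch to the longest remaining length
        return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)

    batch = [torch.arange(5), torch.arange(12), torch.arange(3)]
    print(trim_and_pad(batch).shape)  # torch.Size([3, 8]): 12 tokens trimmed to 8, the rest padded up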
@@ -139,6 +143,14 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
 
     ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
     context = torch.from_numpy(np.hstack(context, dtype=np.int8))
+    if input_ids.shape[-1] > max_length:
+        ids = ids[:max_length]
+        context = context[:max_length]
+        logger.warning(f"The input length ({input_ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated")
+
     if torch.all(context):
+        logger.error("No tokens available to compute loss.")
         raise Exception("No tokens available to compute loss.")
 
     # build target
     target = torch.full_like(ids, -100, dtype=torch.int32)
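The added block clamps both the token ids and the loss-context mask to max_length before the target tensor is built; note that the default max_length=None would make the comparison raise a TypeError, so callers are expected to pass an integer, which the dataset now forwards. A standalone sketch of the clamping on dummy data, where the 10-token input, the mask values, and max_length=6 are invented for illustration:

    import numpy as np
    import torch

    max_length = 6
    input_ids = np.arange(10, dtype=np.int32)             # pretend tokenized conversation
    context = np.array([1] * 4 + [0] * 6, dtype=np.int8)  # 1 = masked from loss, 0 = supervised

    ids = torch.from_numpy(input_ids)
    ctx = torch.from_numpy(context)
    if ids.shape[-1] > max_length:
        ids = ids[:max_length]
        ctx = ctx[:max_length]
        print(f"truncated from {input_ids.shape[-1]} to {max_length} tokens")

    # mirror of the torch.all(context) guard: truncation must leave at least one supervised token
    assert not torch.all(ctx), "No tokens available to compute loss."
    print(ids, ctx)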
@@ -164,7 +176,8 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False
         image_start_tokens += 1
         image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
     if len(image_start_tokens) != len(image_end_tokens):
-        print("image start token != image end tokens")
+        logger.error("image start token != image end tokens")
+        raise Exception("image start token != image end tokens")
 
     if len(image_start_tokens) > 0:
         image_bound = torch.hstack(
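The stricter check above turns a mismatch between image start and end markers into a hard failure instead of a printed message. For orientation, a sketch of how marker positions are typically paired into image_bound spans with torch.where and torch.hstack; the token ids (101/102) and the unsqueeze-based pairing are assumptions for illustration, not taken verbatim from the file:

    import torch

    im_start_id, im_end_id = 101, 102
    ids = torch.tensor([7, 101, 5, 5, 102, 8, 101, 5, 102])

    image_start_tokens = torch.where(ids == im_start_id)[0] + 1  # first position inside each image span
    image_end_tokens = torch.where(ids == im_end_id)[0]

    if len(image_start_tokens) != len(image_end_tokens):
        raise Exception("image start token != image end tokens")

    image_bound = torch.hstack(
        [image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
    )
    print(image_bound)  # tensor([[2, 4], [7, 8]])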
@@ -294,7 +307,6 @@ def conversation_to_ids_qwen2(conversation, tokenizer):
     return input_ids, context, raw_msg
 
 
-
 def preprocess(
     images_dict,
     conversations,
@@ -305,6 +317,7 @@ def preprocess(
     llm_type=None,
     patch_size=14,
     batch_vision=False,
+    max_length=None,
 ):
     """
     single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
@@ -369,7 +382,7 @@ def preprocess(
         conversations[0]["content"] = (
             image_placeholder + "\n" + conversation[0]["content"]
         )
-        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
+        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
     else:
         pattern = r'<image_\d+>'
         new_conversations = []
@@ -377,16 +390,19 @@ def preprocess(
             content = conversation['content']
             parts = re.split(f'({pattern})', content)
             for i, part in enumerate(parts):
                 if not part.strip():
                     continue
                 if re.match(pattern, part):
                     if part in image_placeholder_dict:
                         parts[i] = image_placeholder_dict[part]
                     else:
-                        print(f'Unreplaced image tag: {part}')
-            conversation['content'] = ''.join(parts)
+                        raise Exception(f"not found {part} in image dict")
+            conversation['content'] = '\n'.join(parts)
             new_conversations.append(conversation)
         conversations = new_conversations
 
-        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema)
+        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
 
     if batch_vision:
         tgt_sizes = []
         reshape_images = []
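The final hunk makes an unresolved <image_N> tag a hard error and joins the split parts with newlines instead of plain concatenation. A small self-contained sketch of that re.split and substitution logic, with an invented placeholder dictionary and content string:

    import re

    pattern = r'<image_\d+>'
    image_placeholder_dict = {"<image_00>": "(expanded placeholder for image 0)"}

    content = "Compare <image_00> with the description below."
    parts = re.split(f'({pattern})', content)
    for i, part in enumerate(parts):
        if not part.strip():
            continue
        if re.match(pattern, part):
            if part in image_placeholder_dict:
                parts[i] = image_placeholder_dict[part]
            else:
                raise Exception(f"not found {part} in image dict")
    print('\n'.join(parts))  # three lines: text before, expanded placeholder, text after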