Mirror of https://github.com/OpenBMB/MiniCPM-V.git, synced 2026-02-05 18:29:18 +08:00
Update LoRA finetuning code (#154)
* update lora tuning
* update lora fine-tuning code
* update finetuning lora code
* lora code
* lora finetuning code
* updating lora finetuning code
* update lora finetuning code
* Update Lora finetuning code
* Update LoRA finetuning code
* Update LoRA finetuning code
@@ -56,6 +56,7 @@ class SupervisedDataset(Dataset):
        )
        ret = dict(
            input_ids=ret["input_ids"],
            position_ids=ret["position_ids"],
            labels=ret["target"],
            attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
            pixel_values=ret["pixel_values"],
@@ -72,6 +73,11 @@ def data_collator(examples, padding_value=0):
        batch_first=True,
        padding_value=padding_value,
    )
    position_ids = pad_sequence(
        [example["position_ids"] for example in examples],
        batch_first=True,
        padding_value=padding_value,
    )
    targets = pad_sequence(
        [example["labels"] for example in examples],
        batch_first=True,
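For context, a minimal sketch of the padding behavior this collator relies on, assuming each example holds 1-D tensors as in the hunk above (the batch contents here are made up for illustration):

import torch
from torch.nn.utils.rnn import pad_sequence

# Two examples of different lengths, mirroring a field used by data_collator.
examples = [
    {"position_ids": torch.tensor([0, 1, 2])},
    {"position_ids": torch.tensor([0, 1])},
]

# pad_sequence right-pads every sequence to the longest one in the batch.
position_ids = pad_sequence(
    [example["position_ids"] for example in examples],
    batch_first=True,
    padding_value=0,
)
print(position_ids)  # tensor([[0, 1, 2], [0, 1, 0]])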
@@ -87,6 +93,7 @@ def data_collator(examples, padding_value=0):
    tgt_sizes = [example["tgt_sizes"] for example in examples]
    return {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "labels": targets,
        "attention_mask": attention_mask,
        "image_bound": image_bound,
@@ -130,6 +137,7 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
    image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
    if len(image_start_tokens) != len(image_end_tokens):
        print("image start token != image end tokens")

    if len(image_start_tokens) > 0:
        image_bound = torch.hstack(
            [image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
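For context, a minimal sketch of the image_bound computation in this hunk; the token ids below are hypothetical stand-ins for tokenizer.im_start_id and tokenizer.im_end_id:

import torch

IM_START_ID, IM_END_ID = 101, 102           # hypothetical special-token ids
ids = torch.tensor([7, 101, 3, 4, 102, 8])  # one image span inside a sequence

# Indices of the image start/end markers in the token sequence.
image_start_tokens = torch.where(ids == IM_START_ID)[0]
image_end_tokens = torch.where(ids == IM_END_ID)[0]

# hstack pairs each start index with its end index: one (start, end) row per image.
image_bound = torch.hstack(
    [image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
)
print(image_bound)  # tensor([[1, 4]])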
@@ -137,11 +145,13 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
    else:
        image_bound = []

    position_ids = torch.where(ids != 0, torch.arange(ids.size(0)), torch.tensor(0)).long()
    return {
        "input_ids": ids,
        "target": target,
        "image_bound": image_bound,
        "raw_msg": raw_msg,
        "position_ids": position_ids
    }
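For context, a minimal sketch of the new position_ids line: real tokens keep their absolute position, while pad tokens (id 0) collapse to position 0 (the example sequence is made up):

import torch

ids = torch.tensor([5, 6, 7, 0, 0])  # two trailing pad tokens
position_ids = torch.where(
    ids != 0,                    # True for real tokens
    torch.arange(ids.size(0)),   # keep the absolute position
    torch.tensor(0),             # padding collapses to 0
).long()
print(position_ids)  # tensor([0, 1, 2, 0, 0])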