Update LoRA finetuning code (#154)

* update lora tuning

* update lora fine-tuning code

* update finetuning lora code

* lora code

* lora finetuning code

* updating lora finetuning code

* update lora finetuning code

* Update LoRA finetuning code

* Update LoRA finetuning code

* Update LoRA finetuning code
qianyu chen authored 2024-05-27 19:02:59 +08:00, committed by GitHub
parent 2b572c9221
commit 7e12387362
7 changed files with 261 additions and 32 deletions


@@ -56,6 +56,7 @@ class SupervisedDataset(Dataset):
         )
         ret = dict(
             input_ids=ret["input_ids"],
+            position_ids=ret["position_ids"],
             labels=ret["target"],
             attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
             pixel_values=ret["pixel_values"],
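Not part of the diff, just an illustrative sketch of the per-example dict that `__getitem__` now returns, with made-up tensor values; the collator changes below pad these fields across a batch:

```python
import torch

# Made-up token ids for one example; position_ids comes from
# conversation_to_ids (see the last hunk below).
input_ids = torch.tensor([101, 42, 7])
position_ids = torch.tensor([0, 1, 2])

example = dict(
    input_ids=input_ids,
    position_ids=position_ids,
    labels=torch.tensor([-100, 42, 7]),  # -100 is the usual ignore index
    attention_mask=torch.ones_like(input_ids, dtype=torch.bool),
    # pixel_values omitted in this sketch
)
```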
@@ -72,6 +73,11 @@ def data_collator(examples, padding_value=0):
         batch_first=True,
         padding_value=padding_value,
     )
+    position_ids = pad_sequence(
+        [example["position_ids"] for example in examples],
+        batch_first=True,
+        padding_value=padding_value,
+    )
     targets = pad_sequence(
         [example["labels"] for example in examples],
         batch_first=True,
@@ -87,6 +93,7 @@ def data_collator(examples, padding_value=0):
     tgt_sizes = [example["tgt_sizes"] for example in examples]
     return {
         "input_ids": input_ids,
+        "position_ids": position_ids,
         "labels": targets,
         "attention_mask": attention_mask,
         "image_bound": image_bound,
@@ -130,6 +137,7 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
     image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
     if len(image_start_tokens) != len(image_end_tokens):
         print("image start token != image end tokens")
+    if len(image_start_tokens) > 0:
         image_bound = torch.hstack(
             [image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
@@ -137,11 +145,13 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
     else:
         image_bound = []
+    position_ids = torch.where(ids != 0, torch.arange(ids.size(0)), torch.tensor(0)).long()
     return {
         "input_ids": ids,
         "target": target,
         "image_bound": image_bound,
         "raw_msg": raw_msg,
+        "position_ids": position_ids
     }
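For context on the new `position_ids` expression: it assigns each non-zero token its sequence index and maps zero tokens to position 0. A minimal sketch with made-up token ids:

```python
import torch

# Made-up token ids where the trailing zeros are padding.
ids = torch.tensor([101, 42, 7, 0, 0])

# Non-zero tokens get their index; zero (padding) tokens get position 0.
position_ids = torch.where(
    ids != 0, torch.arange(ids.size(0)), torch.tensor(0)
).long()
print(position_ids)  # tensor([0, 1, 2, 0, 0])
```

Note that this keys off token id 0, so it assumes id 0 appears only as padding.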