Model Fine-tuning Memory Usage Statistics (#160)

This commit is contained in:
qianyu chen
2024-05-28 11:41:27 +08:00
committed by GitHub
parent 7e12387362
commit f592fedb2e
4 changed files with 30 additions and 10 deletions

View File

@@ -66,24 +66,26 @@ class SupervisedDataset(Dataset):
return ret
def data_collator(examples, padding_value=0, max_length=2048):
def trim_and_pad(seq, batch_first, padding_value):
return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
def data_collator(examples, padding_value=0):
input_ids = pad_sequence(
input_ids = trim_and_pad(
[example["input_ids"] for example in examples],
batch_first=True,
padding_value=padding_value,
)
position_ids = pad_sequence(
position_ids = trim_and_pad(
[example["position_ids"] for example in examples],
batch_first=True,
padding_value=padding_value,
)
targets = pad_sequence(
targets = trim_and_pad(
[example["labels"] for example in examples],
batch_first=True,
padding_value=padding_value,
padding_value=-100,
)
attention_mask = pad_sequence(
attention_mask = trim_and_pad(
[example["attention_mask"] for example in examples],
batch_first=True,
padding_value=padding_value,