Merge pull request #708 from BUAADreamer/main

[Feature] Support video sft and fix some training bugs
This commit is contained in:
qianyu chen
2025-01-14 15:16:08 +08:00
committed by GitHub
4 changed files with 170 additions and 22 deletions

View File

@@ -7,6 +7,7 @@ import re
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from decord import VideoReader, cpu # pip install decord
import numpy as np
import torch
@@ -20,6 +21,26 @@ logger = logging.getLogger(__name__)
llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
MAX_NUM_FRAMES=64
def encode_video(video_path, max_num_frames=64):
max_num_frames = min(max_num_frames, MAX_NUM_FRAMES)
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_idx = [i for i in range(0, len(vr), sample_fps)]
if len(frame_idx) > max_num_frames:
if max_num_frames==1:
frame_idx = [frame_idx[len(frame_idx)//2]]
else:
frame_idx = uniform_sample(frame_idx, max_num_frames)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
return frames
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
@@ -34,6 +55,8 @@ class SupervisedDataset(Dataset):
query_nums=64,
batch_vision=False,
max_length=2048,
video_max_slice_nums=2,
max_num_frames=1,
):
super(SupervisedDataset, self).__init__()
self.raw_data = raw_data
@@ -45,17 +68,58 @@ class SupervisedDataset(Dataset):
self.query_nums=query_nums
self.batch_vision = batch_vision
self.max_length = max_length
# video config
self.video_slice_config = copy.deepcopy(slice_config)
self.video_slice_config['max_slice_nums'] = video_max_slice_nums
self.max_num_frames = max_num_frames
def __len__(self):
return len(self.raw_data)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
try:
if isinstance(self.raw_data[i]["image"], str):
images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
elif isinstance(self.raw_data[i]["image"], Dict):
### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
# default: sft image
use_image_id = True
slice_config = self.slice_config
if "image" in self.raw_data[i]:
if isinstance(self.raw_data[i]["image"], str):
images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
elif isinstance(self.raw_data[i]["image"], Dict):
### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
elif "video" in self.raw_data[i]:
if isinstance(self.raw_data[i]["video"], str):
frames = encode_video(self.raw_data[i]["video"], max_num_frames=self.max_num_frames)
image_names = []
images_dict = {}
for j, frame in enumerate(frames):
image_name = "<image_{:02d}>".format(j)
images_dict[image_name] = frame
image_names.append(image_name)
for j in range(len(self.raw_data[i]["conversations"])):
content = self.raw_data[i]["conversations"][j]['content']
self.raw_data[i]["conversations"][j]['content'] = content.replace("<video>", "".join(image_names))
elif isinstance(self.raw_data[i]["video"], Dict):
videos = self.raw_data[i]["video"]
images_dict = {}
video_names = {}
cnt = 0
for video_name in videos:
video_id = video_name.split("_")[-1].strip(">")
video = videos[video_name]
frames = encode_video(video, max_num_frames=self.max_num_frames)
image_names = []
for j, frame in enumerate(frames):
image_name = "<image_{:02d}>".format(cnt)
cnt += 1
images_dict[image_name] = frame
image_names.append(image_name)
for j in range(len(self.raw_data[i]["conversations"])):
content = self.raw_data[i]["conversations"][j]['content']
self.raw_data[i]["conversations"][j]['content'] = content.replace(video_name, "".join(image_names))
# video: modify config
slice_config = self.video_slice_config
use_image_id = False
ret = preprocess(
images_dict,
@@ -67,7 +131,8 @@ class SupervisedDataset(Dataset):
llm_type=self.llm_type,
patch_size=self.patch_size,
batch_vision=self.batch_vision,
max_length=self.max_length
max_length=self.max_length,
use_image_id=use_image_id
)
ret = dict(
input_ids=ret["input_ids"],
@@ -318,6 +383,7 @@ def preprocess(
patch_size=14,
batch_vision=False,
max_length=2048,
use_image_id=True
):
"""
single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
@@ -338,7 +404,7 @@ def preprocess(
use_image_id = False
if llm_type=='qwen2':
new_schema = True
use_image_id = True
use_image_id = use_image_id
image_placeholder_dict = {}
images = []
image_id_cnt = 0

View File

@@ -14,7 +14,7 @@ from accelerate.utils import DistributedType
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from transformers import AutoModel, AutoTokenizer
from transformers import AutoModel, AutoTokenizer, AutoProcessor
from transformers.integrations import deepspeed
from transformers import AutoModel, AutoTokenizer
@@ -53,6 +53,8 @@ class TrainingArguments(transformers.TrainingArguments):
llm_type: str = field(default="minicpm")
use_lora: Optional[bool] = field(default=False)
max_slice_nums: Optional[int] = field(default=9)
video_max_slice_nums: Optional[int] = field(default=2)
max_num_frames: Optional[int] = field(default=1)
@dataclass
@@ -92,6 +94,8 @@ def make_supervised_data_module(
query_nums=64,
batch_vision=False,
max_length=2048,
video_max_slice_nums=2,
max_num_frames=1,
) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
dataset_cls = SupervisedDataset
@@ -109,6 +113,8 @@ def make_supervised_data_module(
query_nums=query_nums,
batch_vision=batch_vision,
max_length=max_length,
video_max_slice_nums=video_max_slice_nums,
max_num_frames=max_num_frames,
)
if data_args.eval_data_path:
@@ -123,6 +129,8 @@ def make_supervised_data_module(
query_nums=query_nums,
batch_vision=batch_vision,
max_length=max_length,
video_max_slice_nums=video_max_slice_nums,
max_num_frames=max_num_frames,
)
else:
eval_dataset = None
@@ -203,6 +211,9 @@ def train():
torch_dtype=compute_dtype,
device_map=device_map,
)
model.__class__.register_for_auto_class()
model.processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=True
@@ -276,6 +287,8 @@ def train():
query_nums=model.config.query_num,
batch_vision=batch_vision,
max_length=training_args.model_max_length,
video_max_slice_nums=training_args.video_max_slice_nums,
max_num_frames=training_args.max_num_frames,
)
training_args.gradient_checkpointing_kwargs={"use_reentrant":False}

View File

@@ -20,30 +20,30 @@ If your input consists of a single image, you can use a single placeholder **\<i
[
{
"id": "0",
"image": 'path/to/image_0.jpg',
"image": "path/to/image_0.jpg",
"conversations": [
{
'role': 'user',
'content': '<image>\nHow many desserts are on the white plate?'
"role": "user",
"content": "<image>\nHow many desserts are on the white plate?"
},
{
'role': 'assistant',
'content': 'There are three desserts on the white plate.'
"role": "assistant",
"content": "There are three desserts on the white plate."
},
{
'role': 'user',
'content': 'What type of desserts are they?'
"role": "user",
"content": "What type of desserts are they?"
},
{
'role': 'assistant',
'content': 'The desserts are cakes with bananas and pecans on top. They share similarities with donuts, but the presence of bananas and pecans differentiates them.'
"role": "assistant",
"content": "The desserts are cakes with bananas and pecans on top. They share similarities with donuts, but the presence of bananas and pecans differentiates them."
},
{
'role': 'user',
'content': 'What is the setting of the image?'},
"role": "user",
"content": "What is the setting of the image?"},
{
'role': 'assistant',
'content': 'The image is set on a table top with a plate containing the three desserts.'
"role": "assistant",
"content": "The image is set on a table top with a plate containing the three desserts."
},
]
},
@@ -91,6 +91,72 @@ If the total token count exceeds `max_length`, truncation will be applied. For m
```
</details>
#### Single Video Example
If your input consists of a single video, you can use a single placeholder **\<video\>** to indicate where the video should be inserted in the conversation.
<details>
<summary>
<b>Single video example (vl_finetune_video.json) with 1 samples.</b>
</summary>
```
[
{
"id": "0",
"video": "path/to/video_0.mp4",
"conversations": [
{
"role": "user",
"content": "<video>\nHow many desserts are on the white plate?"
},
{
"role": "assistant",
"content": "There are three desserts on the white plate."
}
]
}
]
```
</details>
#### Multiple Videos Example
For inputs containing multiple videos, utilize a dictionary where each key represents a unique placeholder (e.g., **\<video_00\>**, **\<video_01\**) with the corresponding video path as its value. These placeholders can then be used within the conversation to seamlessly insert videos at specific positions.
Additionally, to optimize resource management, especially when dealing with large batches of videos during training or inference, consider reducing `video_max_slice_nums` and `max_num_frames`. To minimize the number of tokens used per video, you can set `video_max_slice_nums=1` and `max_num_frames=1`, resulting in a single video being represented by 64 tokens.
If the total token count exceeds `max_length`, truncation will be applied. For multi-video supervised fine-tuning (SFT), it's recommended to set `MODEL_MAX_LENGTH=4096` in your script for better performance.
<details>
<summary>
<b>Multiple videos example (vl_finetune_data.json) with 1 samples.</b>
</summary>
```
[
{
"id": "0",
"video": {
"<video_00>": "path/to/video_0.mp4",
"<video_01>": "path/to/video_1.avi",
"<video_02>": "path/to/video_2.mp4",
"<video_03>": "path/to/video_3.avi"
},
"conversations": [
{
"role": "user",
"content": "How to create such text-only videos using CapCut?\n<video_00>\n<image_01>\n<video_01>\n<video_02>\n"
},
{
"role": "assistant",
"content": "To create a text-only video as shown in the videos, follow these steps in CapCut..."
}
]
}
]
```
</details>
### Full-parameter finetuning
Full-parameter parameter finetuning requires updating all parameters of LLM in the whole training process. Please specify the correct MODEL path, DATA path and LLM_TYPE in the shell scripts.

View File

@@ -170,7 +170,7 @@ class CPMTrainer(Trainer):
return (loss, logits, labels)
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: int=None) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
@@ -245,6 +245,9 @@ class CPMTrainer(Trainer):
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
if self.model.processor is not None:
self.model.processor.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))