Merge pull request #708 from BUAADreamer/main

[Feature] Support video sft and fix some training bugs
This commit is contained in:
qianyu chen
2025-01-14 15:16:08 +08:00
committed by GitHub
4 changed files with 170 additions and 22 deletions

View File

@@ -7,6 +7,7 @@ import re
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from decord import VideoReader, cpu # pip install decord
import numpy as np
import torch
@@ -20,6 +21,26 @@ logger = logging.getLogger(__name__)
llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
MAX_NUM_FRAMES = 64

def encode_video(video_path, max_num_frames=64):
    # cap the requested frame count at the global maximum
    max_num_frames = min(max_num_frames, MAX_NUM_FRAMES)

    def uniform_sample(l, n):
        # pick n indices spread evenly over l
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]

    vr = VideoReader(video_path, ctx=cpu(0))
    sample_fps = round(vr.get_avg_fps() / 1)  # stride of ~1 second between sampled frames
    frame_idx = [i for i in range(0, len(vr), sample_fps)]
    if len(frame_idx) > max_num_frames:
        if max_num_frames == 1:
            frame_idx = [frame_idx[len(frame_idx) // 2]]  # keep the middle frame
        else:
            frame_idx = uniform_sample(frame_idx, max_num_frames)
    frames = vr.get_batch(frame_idx).asnumpy()
    frames = [Image.fromarray(v.astype('uint8')) for v in frames]
    return frames
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""
@@ -34,6 +55,8 @@ class SupervisedDataset(Dataset):
        query_nums=64,
        batch_vision=False,
        max_length=2048,
        video_max_slice_nums=2,
        max_num_frames=1,
    ):
        super(SupervisedDataset, self).__init__()
        self.raw_data = raw_data
@@ -45,17 +68,58 @@ class SupervisedDataset(Dataset):
        self.query_nums=query_nums
        self.batch_vision = batch_vision
        self.max_length = max_length
        # video config
        self.video_slice_config = copy.deepcopy(slice_config)
        self.video_slice_config['max_slice_nums'] = video_max_slice_nums
        self.max_num_frames = max_num_frames
    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        try:
if isinstance(self.raw_data[i]["image"], str): # default: sft image
images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") } use_image_id = True
elif isinstance(self.raw_data[i]["image"], Dict): slice_config = self.slice_config
### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01> if "image" in self.raw_data[i]:
images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()} if isinstance(self.raw_data[i]["image"], str):
images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
elif isinstance(self.raw_data[i]["image"], Dict):
### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
elif "video" in self.raw_data[i]:
if isinstance(self.raw_data[i]["video"], str):
frames = encode_video(self.raw_data[i]["video"], max_num_frames=self.max_num_frames)
image_names = []
images_dict = {}
for j, frame in enumerate(frames):
image_name = "<image_{:02d}>".format(j)
images_dict[image_name] = frame
image_names.append(image_name)
for j in range(len(self.raw_data[i]["conversations"])):
content = self.raw_data[i]["conversations"][j]['content']
self.raw_data[i]["conversations"][j]['content'] = content.replace("<video>", "".join(image_names))
elif isinstance(self.raw_data[i]["video"], Dict):
videos = self.raw_data[i]["video"]
images_dict = {}
video_names = {}
cnt = 0
for video_name in videos:
video_id = video_name.split("_")[-1].strip(">")
video = videos[video_name]
frames = encode_video(video, max_num_frames=self.max_num_frames)
image_names = []
for j, frame in enumerate(frames):
image_name = "<image_{:02d}>".format(cnt)
cnt += 1
images_dict[image_name] = frame
image_names.append(image_name)
for j in range(len(self.raw_data[i]["conversations"])):
content = self.raw_data[i]["conversations"][j]['content']
self.raw_data[i]["conversations"][j]['content'] = content.replace(video_name, "".join(image_names))
# video: modify config
slice_config = self.video_slice_config
use_image_id = False
            ret = preprocess(
                images_dict,
@@ -67,7 +131,8 @@ class SupervisedDataset(Dataset):
                llm_type=self.llm_type,
                patch_size=self.patch_size,
                batch_vision=self.batch_vision,
                max_length=self.max_length,
                use_image_id=use_image_id
            )
            ret = dict(
                input_ids=ret["input_ids"],
@@ -318,6 +383,7 @@ def preprocess(
    patch_size=14,
    batch_vision=False,
    max_length=2048,
    use_image_id=True
):
    """
    single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
@@ -338,7 +404,7 @@ def preprocess(
    use_image_id = False
    if llm_type=='qwen2':
        new_schema = True
        use_image_id = use_image_id  # keep the caller-provided value instead of forcing True
    image_placeholder_dict = {}
    images = []
    image_id_cnt = 0
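To make the placeholder rewriting in `__getitem__` concrete, here is a toy sketch of what happens to a single-video sample (the strings below are illustrative stand-ins, not real frames):

```python
# A single "<video>" tag in the conversation is replaced by one "<image_XX>"
# placeholder per sampled frame, so the existing image preprocessing is reused as-is.
frames = ["frame0", "frame1", "frame2"]  # stand-ins for the PIL frames from encode_video
image_names = ["<image_{:02d}>".format(j) for j in range(len(frames))]
content = "<video>\nDescribe what happens in the clip."
print(content.replace("<video>", "".join(image_names)))
# -> <image_00><image_01><image_02>
#    Describe what happens in the clip.
```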

View File

@@ -14,7 +14,7 @@ from accelerate.utils import DistributedType
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from transformers import AutoModel, AutoTokenizer, AutoProcessor
from transformers.integrations import deepspeed
from transformers import AutoModel, AutoTokenizer
@@ -53,6 +53,8 @@ class TrainingArguments(transformers.TrainingArguments):
    llm_type: str = field(default="minicpm")
    use_lora: Optional[bool] = field(default=False)
    max_slice_nums: Optional[int] = field(default=9)
    video_max_slice_nums: Optional[int] = field(default=2)
    max_num_frames: Optional[int] = field(default=1)

@dataclass
@@ -92,6 +94,8 @@ def make_supervised_data_module(
    query_nums=64,
    batch_vision=False,
    max_length=2048,
    video_max_slice_nums=2,
    max_num_frames=1,
) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    dataset_cls = SupervisedDataset
@@ -109,6 +113,8 @@ def make_supervised_data_module(
        query_nums=query_nums,
        batch_vision=batch_vision,
        max_length=max_length,
        video_max_slice_nums=video_max_slice_nums,
        max_num_frames=max_num_frames,
    )

    if data_args.eval_data_path:
@@ -123,6 +129,8 @@ def make_supervised_data_module(
            query_nums=query_nums,
            batch_vision=batch_vision,
            max_length=max_length,
            video_max_slice_nums=video_max_slice_nums,
            max_num_frames=max_num_frames,
        )
    else:
        eval_dataset = None
@@ -203,6 +211,9 @@ def train():
        torch_dtype=compute_dtype,
        device_map=device_map,
    )
    model.__class__.register_for_auto_class()
    model.processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True
@@ -276,6 +287,8 @@ def train():
        query_nums=model.config.query_num,
        batch_vision=batch_vision,
        max_length=training_args.model_max_length,
        video_max_slice_nums=training_args.video_max_slice_nums,
        max_num_frames=training_args.max_num_frames,
    )

    training_args.gradient_checkpointing_kwargs={"use_reentrant":False}
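Because the new fields live on the `TrainingArguments` dataclass, `HfArgumentParser` exposes them as command-line flags automatically. A minimal sketch, assuming the field names above (the output directory and flag values are illustrative):

```python
import transformers

# TrainingArguments here is the subclass defined in this file (with the new video fields).
parser = transformers.HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--video_max_slice_nums", "1", "--max_num_frames", "8"]
)
print(training_args.video_max_slice_nums, training_args.max_num_frames)  # 1 8
```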

View File

@@ -20,30 +20,30 @@ If your input consists of a single image, you can use a single placeholder **\<i
[
  {
    "id": "0",
    "image": "path/to/image_0.jpg",
    "conversations": [
      {
        "role": "user",
        "content": "<image>\nHow many desserts are on the white plate?"
      },
      {
        "role": "assistant",
        "content": "There are three desserts on the white plate."
      },
      {
        "role": "user",
        "content": "What type of desserts are they?"
      },
      {
        "role": "assistant",
        "content": "The desserts are cakes with bananas and pecans on top. They share similarities with donuts, but the presence of bananas and pecans differentiates them."
      },
      {
        "role": "user",
        "content": "What is the setting of the image?"
      },
      {
        "role": "assistant",
        "content": "The image is set on a table top with a plate containing the three desserts."
      }
    ]
  },
@@ -91,6 +91,72 @@ If the total token count exceeds `max_length`, truncation will be applied. For m
```
</details>
#### Single Video Example
If your input consists of a single video, you can use a single placeholder **\<video\>** to indicate where the video should be inserted in the conversation.
<details>
<summary>
<b>Single video example (vl_finetune_video.json) with 1 sample.</b>
</summary>
```
[
  {
    "id": "0",
    "video": "path/to/video_0.mp4",
    "conversations": [
      {
        "role": "user",
        "content": "<video>\nHow many desserts are on the white plate?"
      },
      {
        "role": "assistant",
        "content": "There are three desserts on the white plate."
      }
    ]
  }
]
```
</details>
#### Multiple Videos Example
For inputs containing multiple videos, use a dictionary where each key is a unique placeholder (e.g., **\<video_00\>**, **\<video_01\>**) and the corresponding video path is its value. These placeholders can then be used within the conversation to insert videos at specific positions.
Additionally, to optimize resource management, especially when dealing with large batches of videos during training or inference, consider reducing `video_max_slice_nums` and `max_num_frames`. To minimize the number of tokens used per video, you can set `video_max_slice_nums=1` and `max_num_frames=1`, resulting in a single video being represented by 64 tokens.
If the total token count exceeds `max_length`, truncation will be applied. For multi-video supervised fine-tuning (SFT), it's recommended to set `MODEL_MAX_LENGTH=4096` in your script for better performance.
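As a back-of-the-envelope check of the budget above (assuming, as stated, that each sampled frame costs roughly 64 vision tokens when `video_max_slice_nums=1`; text tokens come on top of this):

```python
# Rough estimate of the vision tokens consumed by the videos in one training sample.
def estimate_video_tokens(num_videos, max_num_frames=1, tokens_per_frame=64):
    return num_videos * max_num_frames * tokens_per_frame

print(estimate_video_tokens(num_videos=4, max_num_frames=1))  # 256
print(estimate_video_tokens(num_videos=4, max_num_frames=8))  # 2048 -> MODEL_MAX_LENGTH=4096 leaves room for text
```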
<details>
<summary>
<b>Multiple videos example (vl_finetune_data.json) with 1 sample.</b>
</summary>
```
[
  {
    "id": "0",
    "video": {
      "<video_00>": "path/to/video_0.mp4",
      "<video_01>": "path/to/video_1.avi",
      "<video_02>": "path/to/video_2.mp4",
      "<video_03>": "path/to/video_3.avi"
    },
    "conversations": [
      {
        "role": "user",
        "content": "How to create such text-only videos using CapCut?\n<video_00>\n<video_01>\n<video_02>\n<video_03>\n"
      },
      {
        "role": "assistant",
        "content": "To create a text-only video as shown in the videos, follow these steps in CapCut..."
      }
    ]
  }
]
```
</details>
### Full-parameter finetuning
Full-parameter finetuning updates all parameters of the LLM during the whole training process. Please specify the correct MODEL path, DATA path and LLM_TYPE in the shell scripts.

View File

@@ -170,7 +170,7 @@ class CPMTrainer(Trainer):
        return (loss, logits, labels)

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: int=None) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.
@@ -245,6 +245,9 @@ class CPMTrainer(Trainer):
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        if self.model.processor is not None:
            self.model.processor.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
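Since the trainer now saves the processor alongside the tokenizer, a finished checkpoint can be reloaded with both from one place. A quick sketch (the checkpoint path is illustrative):

```python
from transformers import AutoProcessor, AutoTokenizer

ckpt = "output/checkpoint-final"  # hypothetical output_dir of a finished run
processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
```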