mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-04 17:59:18 +08:00
Merge pull request #708 from BUAADreamer/main
[Feature] Support video sft and fix some training bugs
This commit is contained in:
@@ -7,6 +7,7 @@ import re
|
|||||||
import random
|
import random
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
from decord import VideoReader, cpu # pip install decord
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -20,6 +21,26 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
|
llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
|
||||||
|
|
||||||
|
MAX_NUM_FRAMES=64
|
||||||
|
def encode_video(video_path, max_num_frames=64):
|
||||||
|
max_num_frames = min(max_num_frames, MAX_NUM_FRAMES)
|
||||||
|
def uniform_sample(l, n):
|
||||||
|
gap = len(l) / n
|
||||||
|
idxs = [int(i * gap + gap / 2) for i in range(n)]
|
||||||
|
return [l[i] for i in idxs]
|
||||||
|
|
||||||
|
vr = VideoReader(video_path, ctx=cpu(0))
|
||||||
|
sample_fps = round(vr.get_avg_fps() / 1) # FPS
|
||||||
|
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
||||||
|
if len(frame_idx) > max_num_frames:
|
||||||
|
if max_num_frames==1:
|
||||||
|
frame_idx = [frame_idx[len(frame_idx)//2]]
|
||||||
|
else:
|
||||||
|
frame_idx = uniform_sample(frame_idx, max_num_frames)
|
||||||
|
frames = vr.get_batch(frame_idx).asnumpy()
|
||||||
|
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
||||||
|
return frames
|
||||||
|
|
||||||
class SupervisedDataset(Dataset):
|
class SupervisedDataset(Dataset):
|
||||||
"""Dataset for supervised fine-tuning."""
|
"""Dataset for supervised fine-tuning."""
|
||||||
|
|
||||||
@@ -34,6 +55,8 @@ class SupervisedDataset(Dataset):
|
|||||||
query_nums=64,
|
query_nums=64,
|
||||||
batch_vision=False,
|
batch_vision=False,
|
||||||
max_length=2048,
|
max_length=2048,
|
||||||
|
video_max_slice_nums=2,
|
||||||
|
max_num_frames=1,
|
||||||
):
|
):
|
||||||
super(SupervisedDataset, self).__init__()
|
super(SupervisedDataset, self).__init__()
|
||||||
self.raw_data = raw_data
|
self.raw_data = raw_data
|
||||||
@@ -45,17 +68,58 @@ class SupervisedDataset(Dataset):
|
|||||||
self.query_nums=query_nums
|
self.query_nums=query_nums
|
||||||
self.batch_vision = batch_vision
|
self.batch_vision = batch_vision
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
# video config
|
||||||
|
self.video_slice_config = copy.deepcopy(slice_config)
|
||||||
|
self.video_slice_config['max_slice_nums'] = video_max_slice_nums
|
||||||
|
self.max_num_frames = max_num_frames
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.raw_data)
|
return len(self.raw_data)
|
||||||
|
|
||||||
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
||||||
try:
|
try:
|
||||||
if isinstance(self.raw_data[i]["image"], str):
|
# default: sft image
|
||||||
images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
|
use_image_id = True
|
||||||
elif isinstance(self.raw_data[i]["image"], Dict):
|
slice_config = self.slice_config
|
||||||
### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
|
if "image" in self.raw_data[i]:
|
||||||
images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
|
if isinstance(self.raw_data[i]["image"], str):
|
||||||
|
images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
|
||||||
|
elif isinstance(self.raw_data[i]["image"], Dict):
|
||||||
|
### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
|
||||||
|
images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
|
||||||
|
elif "video" in self.raw_data[i]:
|
||||||
|
if isinstance(self.raw_data[i]["video"], str):
|
||||||
|
frames = encode_video(self.raw_data[i]["video"], max_num_frames=self.max_num_frames)
|
||||||
|
image_names = []
|
||||||
|
images_dict = {}
|
||||||
|
for j, frame in enumerate(frames):
|
||||||
|
image_name = "<image_{:02d}>".format(j)
|
||||||
|
images_dict[image_name] = frame
|
||||||
|
image_names.append(image_name)
|
||||||
|
for j in range(len(self.raw_data[i]["conversations"])):
|
||||||
|
content = self.raw_data[i]["conversations"][j]['content']
|
||||||
|
self.raw_data[i]["conversations"][j]['content'] = content.replace("<video>", "".join(image_names))
|
||||||
|
elif isinstance(self.raw_data[i]["video"], Dict):
|
||||||
|
videos = self.raw_data[i]["video"]
|
||||||
|
images_dict = {}
|
||||||
|
video_names = {}
|
||||||
|
cnt = 0
|
||||||
|
for video_name in videos:
|
||||||
|
video_id = video_name.split("_")[-1].strip(">")
|
||||||
|
video = videos[video_name]
|
||||||
|
frames = encode_video(video, max_num_frames=self.max_num_frames)
|
||||||
|
image_names = []
|
||||||
|
for j, frame in enumerate(frames):
|
||||||
|
image_name = "<image_{:02d}>".format(cnt)
|
||||||
|
cnt += 1
|
||||||
|
images_dict[image_name] = frame
|
||||||
|
image_names.append(image_name)
|
||||||
|
for j in range(len(self.raw_data[i]["conversations"])):
|
||||||
|
content = self.raw_data[i]["conversations"][j]['content']
|
||||||
|
self.raw_data[i]["conversations"][j]['content'] = content.replace(video_name, "".join(image_names))
|
||||||
|
# video: modify config
|
||||||
|
slice_config = self.video_slice_config
|
||||||
|
use_image_id = False
|
||||||
|
|
||||||
ret = preprocess(
|
ret = preprocess(
|
||||||
images_dict,
|
images_dict,
|
||||||
@@ -67,7 +131,8 @@ class SupervisedDataset(Dataset):
|
|||||||
llm_type=self.llm_type,
|
llm_type=self.llm_type,
|
||||||
patch_size=self.patch_size,
|
patch_size=self.patch_size,
|
||||||
batch_vision=self.batch_vision,
|
batch_vision=self.batch_vision,
|
||||||
max_length=self.max_length
|
max_length=self.max_length,
|
||||||
|
use_image_id=use_image_id
|
||||||
)
|
)
|
||||||
ret = dict(
|
ret = dict(
|
||||||
input_ids=ret["input_ids"],
|
input_ids=ret["input_ids"],
|
||||||
@@ -318,6 +383,7 @@ def preprocess(
|
|||||||
patch_size=14,
|
patch_size=14,
|
||||||
batch_vision=False,
|
batch_vision=False,
|
||||||
max_length=2048,
|
max_length=2048,
|
||||||
|
use_image_id=True
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
|
single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
|
||||||
@@ -338,7 +404,7 @@ def preprocess(
|
|||||||
use_image_id = False
|
use_image_id = False
|
||||||
if llm_type=='qwen2':
|
if llm_type=='qwen2':
|
||||||
new_schema = True
|
new_schema = True
|
||||||
use_image_id = True
|
use_image_id = use_image_id
|
||||||
image_placeholder_dict = {}
|
image_placeholder_dict = {}
|
||||||
images = []
|
images = []
|
||||||
image_id_cnt = 0
|
image_id_cnt = 0
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from accelerate.utils import DistributedType
|
|||||||
from deepspeed import zero
|
from deepspeed import zero
|
||||||
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
||||||
|
|
||||||
from transformers import AutoModel, AutoTokenizer
|
from transformers import AutoModel, AutoTokenizer, AutoProcessor
|
||||||
from transformers.integrations import deepspeed
|
from transformers.integrations import deepspeed
|
||||||
from transformers import AutoModel, AutoTokenizer
|
from transformers import AutoModel, AutoTokenizer
|
||||||
|
|
||||||
@@ -53,6 +53,8 @@ class TrainingArguments(transformers.TrainingArguments):
|
|||||||
llm_type: str = field(default="minicpm")
|
llm_type: str = field(default="minicpm")
|
||||||
use_lora: Optional[bool] = field(default=False)
|
use_lora: Optional[bool] = field(default=False)
|
||||||
max_slice_nums: Optional[int] = field(default=9)
|
max_slice_nums: Optional[int] = field(default=9)
|
||||||
|
video_max_slice_nums: Optional[int] = field(default=2)
|
||||||
|
max_num_frames: Optional[int] = field(default=1)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -92,6 +94,8 @@ def make_supervised_data_module(
|
|||||||
query_nums=64,
|
query_nums=64,
|
||||||
batch_vision=False,
|
batch_vision=False,
|
||||||
max_length=2048,
|
max_length=2048,
|
||||||
|
video_max_slice_nums=2,
|
||||||
|
max_num_frames=1,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Make dataset and collator for supervised fine-tuning."""
|
"""Make dataset and collator for supervised fine-tuning."""
|
||||||
dataset_cls = SupervisedDataset
|
dataset_cls = SupervisedDataset
|
||||||
@@ -109,6 +113,8 @@ def make_supervised_data_module(
|
|||||||
query_nums=query_nums,
|
query_nums=query_nums,
|
||||||
batch_vision=batch_vision,
|
batch_vision=batch_vision,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
|
video_max_slice_nums=video_max_slice_nums,
|
||||||
|
max_num_frames=max_num_frames,
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_args.eval_data_path:
|
if data_args.eval_data_path:
|
||||||
@@ -123,6 +129,8 @@ def make_supervised_data_module(
|
|||||||
query_nums=query_nums,
|
query_nums=query_nums,
|
||||||
batch_vision=batch_vision,
|
batch_vision=batch_vision,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
|
video_max_slice_nums=video_max_slice_nums,
|
||||||
|
max_num_frames=max_num_frames,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
@@ -203,6 +211,9 @@ def train():
|
|||||||
torch_dtype=compute_dtype,
|
torch_dtype=compute_dtype,
|
||||||
device_map=device_map,
|
device_map=device_map,
|
||||||
)
|
)
|
||||||
|
model.__class__.register_for_auto_class()
|
||||||
|
|
||||||
|
model.processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_args.model_name_or_path, trust_remote_code=True
|
model_args.model_name_or_path, trust_remote_code=True
|
||||||
@@ -276,6 +287,8 @@ def train():
|
|||||||
query_nums=model.config.query_num,
|
query_nums=model.config.query_num,
|
||||||
batch_vision=batch_vision,
|
batch_vision=batch_vision,
|
||||||
max_length=training_args.model_max_length,
|
max_length=training_args.model_max_length,
|
||||||
|
video_max_slice_nums=training_args.video_max_slice_nums,
|
||||||
|
max_num_frames=training_args.max_num_frames,
|
||||||
)
|
)
|
||||||
|
|
||||||
training_args.gradient_checkpointing_kwargs={"use_reentrant":False}
|
training_args.gradient_checkpointing_kwargs={"use_reentrant":False}
|
||||||
|
|||||||
@@ -20,30 +20,30 @@ If your input consists of a single image, you can use a single placeholder **\<i
|
|||||||
[
|
[
|
||||||
{
|
{
|
||||||
"id": "0",
|
"id": "0",
|
||||||
"image": 'path/to/image_0.jpg',
|
"image": "path/to/image_0.jpg",
|
||||||
"conversations": [
|
"conversations": [
|
||||||
{
|
{
|
||||||
'role': 'user',
|
"role": "user",
|
||||||
'content': '<image>\nHow many desserts are on the white plate?'
|
"content": "<image>\nHow many desserts are on the white plate?"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'role': 'assistant',
|
"role": "assistant",
|
||||||
'content': 'There are three desserts on the white plate.'
|
"content": "There are three desserts on the white plate."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'role': 'user',
|
"role": "user",
|
||||||
'content': 'What type of desserts are they?'
|
"content": "What type of desserts are they?"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'role': 'assistant',
|
"role": "assistant",
|
||||||
'content': 'The desserts are cakes with bananas and pecans on top. They share similarities with donuts, but the presence of bananas and pecans differentiates them.'
|
"content": "The desserts are cakes with bananas and pecans on top. They share similarities with donuts, but the presence of bananas and pecans differentiates them."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'role': 'user',
|
"role": "user",
|
||||||
'content': 'What is the setting of the image?'},
|
"content": "What is the setting of the image?"},
|
||||||
{
|
{
|
||||||
'role': 'assistant',
|
"role": "assistant",
|
||||||
'content': 'The image is set on a table top with a plate containing the three desserts.'
|
"content": "The image is set on a table top with a plate containing the three desserts."
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -91,6 +91,72 @@ If the total token count exceeds `max_length`, truncation will be applied. For m
|
|||||||
```
|
```
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
#### Single Video Example
|
||||||
|
If your input consists of a single video, you can use a single placeholder **\<video\>** to indicate where the video should be inserted in the conversation.
|
||||||
|
<details>
|
||||||
|
<summary>
|
||||||
|
<b>Single video example (vl_finetune_video.json) with 1 samples.</b>
|
||||||
|
</summary>
|
||||||
|
|
||||||
|
```
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "0",
|
||||||
|
"video": "path/to/video_0.mp4",
|
||||||
|
"conversations": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "<video>\nHow many desserts are on the white plate?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "There are three desserts on the white plate."
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
#### Multiple Videos Example
|
||||||
|
For inputs containing multiple videos, utilize a dictionary where each key represents a unique placeholder (e.g., **\<video_00\>**, **\<video_01\**) with the corresponding video path as its value. These placeholders can then be used within the conversation to seamlessly insert videos at specific positions.
|
||||||
|
|
||||||
|
Additionally, to optimize resource management, especially when dealing with large batches of videos during training or inference, consider reducing `video_max_slice_nums` and `max_num_frames`. To minimize the number of tokens used per video, you can set `video_max_slice_nums=1` and `max_num_frames=1`, resulting in a single video being represented by 64 tokens.
|
||||||
|
|
||||||
|
If the total token count exceeds `max_length`, truncation will be applied. For multi-video supervised fine-tuning (SFT), it's recommended to set `MODEL_MAX_LENGTH=4096` in your script for better performance.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>
|
||||||
|
<b>Multiple videos example (vl_finetune_data.json) with 1 samples.</b>
|
||||||
|
</summary>
|
||||||
|
|
||||||
|
```
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "0",
|
||||||
|
"video": {
|
||||||
|
"<video_00>": "path/to/video_0.mp4",
|
||||||
|
"<video_01>": "path/to/video_1.avi",
|
||||||
|
"<video_02>": "path/to/video_2.mp4",
|
||||||
|
"<video_03>": "path/to/video_3.avi"
|
||||||
|
},
|
||||||
|
"conversations": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "How to create such text-only videos using CapCut?\n<video_00>\n<image_01>\n<video_01>\n<video_02>\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "To create a text-only video as shown in the videos, follow these steps in CapCut..."
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
### Full-parameter finetuning
|
### Full-parameter finetuning
|
||||||
|
|
||||||
Full-parameter parameter finetuning requires updating all parameters of LLM in the whole training process. Please specify the correct MODEL path, DATA path and LLM_TYPE in the shell scripts.
|
Full-parameter parameter finetuning requires updating all parameters of LLM in the whole training process. Please specify the correct MODEL path, DATA path and LLM_TYPE in the shell scripts.
|
||||||
|
|||||||
@@ -170,7 +170,7 @@ class CPMTrainer(Trainer):
|
|||||||
|
|
||||||
return (loss, logits, labels)
|
return (loss, logits, labels)
|
||||||
|
|
||||||
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
|
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: int=None) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Perform a training step on a batch of inputs.
|
Perform a training step on a batch of inputs.
|
||||||
|
|
||||||
@@ -245,6 +245,9 @@ class CPMTrainer(Trainer):
|
|||||||
|
|
||||||
if self.tokenizer is not None:
|
if self.tokenizer is not None:
|
||||||
self.tokenizer.save_pretrained(output_dir)
|
self.tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
if self.model.processor is not None:
|
||||||
|
self.model.processor.save_pretrained(output_dir)
|
||||||
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||||
|
|||||||
Reference in New Issue
Block a user