mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-05 02:09:20 +08:00
support video sft and auto save and load all files
This commit is contained in:
@@ -7,6 +7,7 @@ import re
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
from decord import VideoReader, cpu # pip install decord
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -20,6 +21,26 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
|
||||
|
||||
MAX_NUM_FRAMES=64
|
||||
def encode_video(video_path, max_num_frames=64):
    """Decode a video file into a list of sampled frames.

    Frames are first picked at a stride of ``round(average FPS)`` frames
    (i.e. roughly one frame per second of video).  If that yields more than
    the allowed number of frames, the candidates are thinned out evenly.

    Args:
        video_path: Path of the video, passed straight to ``VideoReader``.
        max_num_frames: Upper bound on the number of returned frames.  It is
            additionally capped by the module-level ``MAX_NUM_FRAMES``.

    Returns:
        A list of images created with ``Image.fromarray`` from the decoded
        frames, in temporal order.
    """
    max_num_frames = min(max_num_frames, MAX_NUM_FRAMES)

    def uniform_sample(seq, n):
        # Split ``seq`` into ``n`` equal-width segments and take the element
        # at the midpoint of each segment.
        gap = len(seq) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [seq[i] for i in idxs]

    vr = VideoReader(video_path, ctx=cpu(0))

    # Stride in frames that corresponds to about one second of video.
    # Clamp to 1: for clips with an average FPS below 0.5 the rounded value
    # is 0, and ``range`` rejects a zero step with ValueError.
    sample_fps = max(1, round(vr.get_avg_fps()))
    frame_idx = list(range(0, len(vr), sample_fps))

    if len(frame_idx) > max_num_frames:
        if max_num_frames == 1:
            # A single frame: use the temporal middle of the candidates.
            frame_idx = [frame_idx[len(frame_idx) // 2]]
        else:
            frame_idx = uniform_sample(frame_idx, max_num_frames)

    frames = vr.get_batch(frame_idx).asnumpy()
    frames = [Image.fromarray(v.astype('uint8')) for v in frames]
    return frames
|
||||
|
||||
class SupervisedDataset(Dataset):
|
||||
"""Dataset for supervised fine-tuning."""
|
||||
|
||||
@@ -34,6 +55,8 @@ class SupervisedDataset(Dataset):
|
||||
query_nums=64,
|
||||
batch_vision=False,
|
||||
max_length=2048,
|
||||
video_max_slice_nums=2,
|
||||
max_num_frames=1,
|
||||
):
|
||||
super(SupervisedDataset, self).__init__()
|
||||
self.raw_data = raw_data
|
||||
@@ -45,17 +68,58 @@ class SupervisedDataset(Dataset):
|
||||
self.query_nums=query_nums
|
||||
self.batch_vision = batch_vision
|
||||
self.max_length = max_length
|
||||
# video config
|
||||
self.video_slice_config = copy.deepcopy(slice_config)
|
||||
self.video_slice_config['max_slice_nums'] = video_max_slice_nums
|
||||
self.max_num_frames = max_num_frames
|
||||
|
||||
def __len__(self):
|
||||
return len(self.raw_data)
|
||||
|
||||
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
||||
try:
|
||||
if isinstance(self.raw_data[i]["image"], str):
|
||||
images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
|
||||
elif isinstance(self.raw_data[i]["image"], Dict):
|
||||
### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
|
||||
images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
|
||||
# default: sft image
|
||||
use_image_id = True
|
||||
slice_config = self.slice_config
|
||||
if "image" in self.raw_data[i]:
|
||||
if isinstance(self.raw_data[i]["image"], str):
|
||||
images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
|
||||
elif isinstance(self.raw_data[i]["image"], Dict):
|
||||
### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
|
||||
images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
|
||||
elif "video" in self.raw_data[i]:
|
||||
if isinstance(self.raw_data[i]["video"], str):
|
||||
frames = encode_video(self.raw_data[i]["video"], max_num_frames=self.max_num_frames)
|
||||
image_names = []
|
||||
images_dict = {}
|
||||
for j, frame in enumerate(frames):
|
||||
image_name = "<image_{:02d}>".format(j)
|
||||
images_dict[image_name] = frame
|
||||
image_names.append(image_name)
|
||||
for j in range(len(self.raw_data[i]["conversations"])):
|
||||
content = self.raw_data[i]["conversations"][j]['content']
|
||||
self.raw_data[i]["conversations"][j]['content'] = content.replace("<video>", "".join(image_names))
|
||||
elif isinstance(self.raw_data[i]["video"], Dict):
|
||||
videos = self.raw_data[i]["video"]
|
||||
images_dict = {}
|
||||
video_names = {}
|
||||
cnt = 0
|
||||
for video_name in videos:
|
||||
video_id = video_name.split("_")[-1].strip(">")
|
||||
video = videos[video_name]
|
||||
frames = encode_video(video, max_num_frames=self.max_num_frames)
|
||||
image_names = []
|
||||
for j, frame in enumerate(frames):
|
||||
image_name = "<image_{:02d}>".format(cnt)
|
||||
cnt += 1
|
||||
images_dict[image_name] = frame
|
||||
image_names.append(image_name)
|
||||
for j in range(len(self.raw_data[i]["conversations"])):
|
||||
content = self.raw_data[i]["conversations"][j]['content']
|
||||
self.raw_data[i]["conversations"][j]['content'] = content.replace(video_name, "".join(image_names))
|
||||
# video: modify config
|
||||
slice_config = self.video_slice_config
|
||||
use_image_id = False
|
||||
|
||||
ret = preprocess(
|
||||
images_dict,
|
||||
@@ -67,7 +131,8 @@ class SupervisedDataset(Dataset):
|
||||
llm_type=self.llm_type,
|
||||
patch_size=self.patch_size,
|
||||
batch_vision=self.batch_vision,
|
||||
max_length=self.max_length
|
||||
max_length=self.max_length,
|
||||
use_image_id=use_image_id
|
||||
)
|
||||
ret = dict(
|
||||
input_ids=ret["input_ids"],
|
||||
@@ -318,6 +383,7 @@ def preprocess(
|
||||
patch_size=14,
|
||||
batch_vision=False,
|
||||
max_length=2048,
|
||||
use_image_id=True
|
||||
):
|
||||
"""
|
||||
single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
|
||||
@@ -338,7 +404,7 @@ def preprocess(
|
||||
use_image_id = False
|
||||
if llm_type=='qwen2':
|
||||
new_schema = True
|
||||
use_image_id = True
|
||||
use_image_id = use_image_id
|
||||
image_placeholder_dict = {}
|
||||
images = []
|
||||
image_id_cnt = 0
|
||||
|
||||
Reference in New Issue
Block a user