support video sft and auto save and load all files

fzc8578
2025-01-11 13:50:36 +08:00
parent 8464c94a7b
commit c5e82b1bc7
4 changed files with 170 additions and 22 deletions


@@ -7,6 +7,7 @@ import re
 import random
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional
+from decord import VideoReader, cpu  # pip install decord
 import numpy as np
 import torch
@@ -20,6 +21,26 @@ logger = logging.getLogger(__name__)
 llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
 
+MAX_NUM_FRAMES = 64
+
+
+def encode_video(video_path, max_num_frames=64):
+    max_num_frames = min(max_num_frames, MAX_NUM_FRAMES)
+
+    def uniform_sample(l, n):
+        gap = len(l) / n
+        idxs = [int(i * gap + gap / 2) for i in range(n)]
+        return [l[i] for i in idxs]
+
+    vr = VideoReader(video_path, ctx=cpu(0))
+    sample_fps = round(vr.get_avg_fps() / 1)  # sample one frame per second
+    frame_idx = [i for i in range(0, len(vr), sample_fps)]
+    if len(frame_idx) > max_num_frames:
+        if max_num_frames == 1:
+            frame_idx = [frame_idx[len(frame_idx) // 2]]  # keep the middle frame
+        else:
+            frame_idx = uniform_sample(frame_idx, max_num_frames)
+    frames = vr.get_batch(frame_idx).asnumpy()
+    frames = [Image.fromarray(v.astype('uint8')) for v in frames]
+    return frames
+
+
 class SupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""
@@ -34,6 +55,8 @@ class SupervisedDataset(Dataset):
         query_nums=64,
         batch_vision=False,
         max_length=2048,
+        video_max_slice_nums=2,
+        max_num_frames=1,
     ):
         super(SupervisedDataset, self).__init__()
         self.raw_data = raw_data
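
Not from the diff: a hypothetical call site for the two new knobs. video_max_slice_nums overrides max_slice_nums in a copied slice_config (see the next hunk): videos contribute many frames, so each frame gets fewer slices than in image SFT. max_num_frames caps how many frames encode_video returns per clip. The remaining constructor arguments are elided here.

    # Hypothetical instantiation; arguments besides the two new
    # keywords are illustrative.
    train_dataset = SupervisedDataset(
        raw_data,
        transform,
        tokenizer,
        slice_config=slice_config,
        video_max_slice_nums=2,  # per-frame max_slice_nums for video samples
        max_num_frames=16,       # upper bound on sampled frames per video
    )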
@@ -45,17 +68,58 @@ class SupervisedDataset(Dataset):
         self.query_nums = query_nums
         self.batch_vision = batch_vision
         self.max_length = max_length
+        # video config: same slice_config as images, but with its own max_slice_nums
+        self.video_slice_config = copy.deepcopy(slice_config)
+        self.video_slice_config['max_slice_nums'] = video_max_slice_nums
+        self.max_num_frames = max_num_frames
 
     def __len__(self):
         return len(self.raw_data)
 
     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
         try:
-            if isinstance(self.raw_data[i]["image"], str):
-                images_dict = {"<image>": Image.open(self.raw_data[i]["image"]).convert("RGB")}
-            elif isinstance(self.raw_data[i]["image"], Dict):
-                ### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
-                images_dict = {img_name: Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
+            # default: image SFT
+            use_image_id = True
+            slice_config = self.slice_config
+            if "image" in self.raw_data[i]:
+                if isinstance(self.raw_data[i]["image"], str):
+                    images_dict = {"<image>": Image.open(self.raw_data[i]["image"]).convert("RGB")}
+                elif isinstance(self.raw_data[i]["image"], Dict):
+                    ### for multi-image input, the template for every image is <image_xx>, such as <image_00>, <image_01>
+                    images_dict = {img_name: Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
+            elif "video" in self.raw_data[i]:
+                if isinstance(self.raw_data[i]["video"], str):
+                    # single video: expand the <video> tag into one <image_xx> tag per sampled frame
+                    frames = encode_video(self.raw_data[i]["video"], max_num_frames=self.max_num_frames)
+                    image_names = []
+                    images_dict = {}
+                    for j, frame in enumerate(frames):
+                        image_name = "<image_{:02d}>".format(j)
+                        images_dict[image_name] = frame
+                        image_names.append(image_name)
+                    for j in range(len(self.raw_data[i]["conversations"])):
+                        content = self.raw_data[i]["conversations"][j]['content']
+                        self.raw_data[i]["conversations"][j]['content'] = content.replace("<video>", "".join(image_names))
+                elif isinstance(self.raw_data[i]["video"], Dict):
+                    # multiple videos: each <video_xx> tag maps to a path and is expanded the same way,
+                    # with a running counter so frame names stay unique across videos
+                    videos = self.raw_data[i]["video"]
+                    images_dict = {}
+                    cnt = 0
+                    for video_name in videos:
+                        frames = encode_video(videos[video_name], max_num_frames=self.max_num_frames)
+                        image_names = []
+                        for j, frame in enumerate(frames):
+                            image_name = "<image_{:02d}>".format(cnt)
+                            cnt += 1
+                            images_dict[image_name] = frame
+                            image_names.append(image_name)
+                        for j in range(len(self.raw_data[i]["conversations"])):
+                            content = self.raw_data[i]["conversations"][j]['content']
+                            self.raw_data[i]["conversations"][j]['content'] = content.replace(video_name, "".join(image_names))
+                # video: use the video slice config and drop per-frame image ids
+                slice_config = self.video_slice_config
+                use_image_id = False
 
             ret = preprocess(
                 images_dict,
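
Implied by the branch above but not spelled out in the diff: the raw_data schema for video samples. A single-video entry stores a path under "video" and marks the insertion point with a <video> tag in its conversation text; a multi-video entry maps <video_xx> tag names to paths. Either way, each tag is rewritten in place into the concatenated <image_XX> placeholders for that clip's sampled frames. Illustrative entries (paths and wording invented):

    single_video_sample = {
        "video": "videos/clip.mp4",
        "conversations": [
            {"role": "user", "content": "<video>\nWhat happens in this clip?"},
            {"role": "assistant", "content": "..."},
        ],
    }

    multi_video_sample = {
        "video": {"<video_00>": "videos/a.mp4", "<video_01>": "videos/b.mp4"},
        "conversations": [
            {"role": "user", "content": "Compare <video_00> with <video_01>."},
            {"role": "assistant", "content": "..."},
        ],
    }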
@@ -67,7 +131,8 @@ class SupervisedDataset(Dataset):
                 llm_type=self.llm_type,
                 patch_size=self.patch_size,
                 batch_vision=self.batch_vision,
-                max_length=self.max_length
+                max_length=self.max_length,
+                use_image_id=use_image_id
             )
             ret = dict(
                 input_ids=ret["input_ids"],
@@ -318,6 +383,7 @@ def preprocess(
     patch_size=14,
     batch_vision=False,
     max_length=2048,
+    use_image_id=True
 ):
     """
     single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
@@ -338,7 +404,7 @@ def preprocess(
         use_image_id = False
     if llm_type=='qwen2':
         new_schema = True
-        use_image_id = True
+        use_image_id = use_image_id
     image_placeholder_dict = {}
     images = []
     image_id_cnt = 0
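
Taken together with the dataset change, the plumbing above lets video samples opt out of per-image ids: preprocess previously forced use_image_id = True for qwen2, and now keeps whatever the caller passed (the self-assignment replaces the hard-coded True). A rough illustration of what the flag changes; the tag strings here are placeholders, not the exact ones preprocess emits:

    # Hypothetical sketch of the toggled placeholder layout.
    def render_frames(n, use_image_id):
        parts = []
        for i in range(n):
            if use_image_id:
                parts.append(f"<image_id>{i}</image_id><image>...</image>")
            else:
                parts.append("<image>...</image>")  # video SFT: no ordinal ids
        return "".join(parts)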