From c5e82b1bc7727b057a4f9e9f1077a8a285c65a3e Mon Sep 17 00:00:00 2001
From: fzc8578 <1428195643@qq.com>
Date: Sat, 11 Jan 2025 13:50:36 +0800
Subject: [PATCH] support video sft and auto save and load all files

---
 finetune/dataset.py  | 80 ++++++++++++++++++++++++++++++++++----
 finetune/finetune.py | 15 +++++++-
 finetune/readme.md   | 92 +++++++++++++++++++++++++++++++++++++-------
 finetune/trainer.py  |  5 ++-
 4 files changed, 170 insertions(+), 22 deletions(-)

diff --git a/finetune/dataset.py b/finetune/dataset.py
index 92edb9b..885ae14 100644
--- a/finetune/dataset.py
+++ b/finetune/dataset.py
@@ -7,6 +7,7 @@ import re
 import random
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional
+from decord import VideoReader, cpu  # pip install decord
 
 import numpy as np
 import torch
@@ -20,6 +21,26 @@ logger = logging.getLogger(__name__)
 
 llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
 
+MAX_NUM_FRAMES = 64
+def encode_video(video_path, max_num_frames=64):
+    max_num_frames = min(max_num_frames, MAX_NUM_FRAMES)
+    def uniform_sample(l, n):
+        # pick n indices spread evenly across the list
+        gap = len(l) / n
+        idxs = [int(i * gap + gap / 2) for i in range(n)]
+        return [l[i] for i in idxs]
+
+    vr = VideoReader(video_path, ctx=cpu(0))
+    sample_fps = round(vr.get_avg_fps() / 1)  # sample roughly 1 frame per second
+    frame_idx = [i for i in range(0, len(vr), sample_fps)]
+    if len(frame_idx) > max_num_frames:
+        if max_num_frames == 1:
+            frame_idx = [frame_idx[len(frame_idx) // 2]]  # keep the middle frame
+        else:
+            frame_idx = uniform_sample(frame_idx, max_num_frames)
+    frames = vr.get_batch(frame_idx).asnumpy()
+    frames = [Image.fromarray(v.astype('uint8')) for v in frames]
+    return frames
+
 class SupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""
 
@@ -34,6 +55,8 @@ class SupervisedDataset(Dataset):
         query_nums=64,
         batch_vision=False,
         max_length=2048,
+        video_max_slice_nums=2,
+        max_num_frames=1,
     ):
         super(SupervisedDataset, self).__init__()
         self.raw_data = raw_data
@@ -45,17 +68,58 @@ class SupervisedDataset(Dataset):
         self.query_nums=query_nums
         self.batch_vision = batch_vision
         self.max_length = max_length
+        # video config: frames reuse the image slice config, but with fewer slices
+        self.video_slice_config = copy.deepcopy(slice_config)
+        self.video_slice_config['max_slice_nums'] = video_max_slice_nums
+        self.max_num_frames = max_num_frames
 
     def __len__(self):
        return len(self.raw_data)
 
     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
         try:
-            if isinstance(self.raw_data[i]["image"], str):
-                images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
-            elif isinstance(self.raw_data[i]["image"], Dict):
-                ### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
-                images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
+            # default: image sft
+            use_image_id = True
+            slice_config = self.slice_config
+            if "image" in self.raw_data[i]:
+                if isinstance(self.raw_data[i]["image"], str):
+                    images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
+                elif isinstance(self.raw_data[i]["image"], Dict):
+                    ### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
+                    images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
+            elif "video" in self.raw_data[i]:
+                if isinstance(self.raw_data[i]["video"], str):
+                    frames = encode_video(self.raw_data[i]["video"], max_num_frames=self.max_num_frames)
+                    image_names = []
+                    images_dict = {}
+                    for j, frame in enumerate(frames):
+                        image_name = "<image_{:02d}>".format(j)
+                        images_dict[image_name] = frame
+                        image_names.append(image_name)
+                    # replace the <video> placeholder with the per-frame image tags
+                    for j in range(len(self.raw_data[i]["conversations"])):
+                        content = self.raw_data[i]["conversations"][j]['content']
+                        self.raw_data[i]["conversations"][j]['content'] = content.replace("<video>", "".join(image_names))
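
To make the new data path concrete, the sketch below builds one video SFT record and replays, in miniature, what __getitem__ now does with it. This is an illustration, not part of the patch: the file path and conversation text are hypothetical, it assumes encode_video is importable from finetune/dataset.py, and the <video> placeholder and <image_00>-style frame names follow the code above.

    # Hypothetical video SFT record: "video" is a path on disk and the user
    # turn carries a <video> placeholder (path and text are made up).
    from dataset import encode_video   # finetune/dataset.py, as patched above

    sample = {
        "video": "data/videos/example.mp4",
        "conversations": [
            {"role": "user", "content": "<video>\nWhat happens in this clip?"},
            {"role": "assistant", "content": "A dog catches a frisbee."},
        ],
    }

    # Decode at most max_num_frames frames, name them <image_00>, <image_01>, ...,
    # then splice those names in place of the <video> tag, mirroring __getitem__.
    frames = encode_video(sample["video"], max_num_frames=2)
    image_names = ["<image_{:02d}>".format(j) for j in range(len(frames))]
    for turn in sample["conversations"]:
        turn["content"] = turn["content"].replace("<video>", "".join(image_names))

    print(sample["conversations"][0]["content"])
    # e.g. <image_00><image_01>\nWhat happens in this clip?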