import copy
import json
import logging
import math
import os
import re
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from decord import VideoReader, cpu  # pip install decord
import numpy as np
import torch
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from transformers import AutoProcessor, AutoTokenizer

logger = logging.getLogger(__name__)

llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"

MAX_NUM_FRAMES = 64


def encode_video(video_path, max_num_frames=64):
    """Decode a video into a list of PIL images, sampled at roughly one frame
    per second and capped at max_num_frames frames."""
    max_num_frames = min(max_num_frames, MAX_NUM_FRAMES)

    def uniform_sample(l, n):
        # Pick n indices evenly spread over l, centered within each interval.
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]

    vr = VideoReader(video_path, ctx=cpu(0))
    sample_fps = round(vr.get_avg_fps() / 1)  # frame stride: ~1 sampled frame per second
    frame_idx = [i for i in range(0, len(vr), sample_fps)]
    if len(frame_idx) > max_num_frames:
        if max_num_frames == 1:
            # Single-frame budget: take the middle frame.
            frame_idx = [frame_idx[len(frame_idx) // 2]]
        else:
            frame_idx = uniform_sample(frame_idx, max_num_frames)
    frames = vr.get_batch(frame_idx).asnumpy()
    frames = [Image.fromarray(v.astype('uint8')) for v in frames]
    return frames


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        raw_data,
        transform,
        tokenizer,
        slice_config,
        llm_type="minicpm",
        patch_size=14,
        query_nums=64,
        batch_vision=False,
        max_length=2048,
        video_max_slice_nums=2,
        max_num_frames=1,
    ):
        super().__init__()
        self.raw_data = raw_data
        self.tokenizer = tokenizer
        self.transform = transform
        self.slice_config = slice_config
        self.llm_type = llm_type
        self.patch_size = patch_size
        self.query_nums = query_nums
        self.batch_vision = batch_vision
        self.max_length = max_length
        # video config: reuse the image slice config with a smaller slice budget
        self.video_slice_config = copy.deepcopy(slice_config)
        self.video_slice_config['max_slice_nums'] = video_max_slice_nums
        self.max_num_frames = max_num_frames

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        try:
            # defaults: single/multi-image SFT; overridden below for video
            use_image_id = True
            slice_config = self.slice_config
            if "image" in self.raw_data[i]:
                if isinstance(self.raw_data[i]["image"], str):
                    images_dict = {
                        "<image>": Image.open(self.raw_data[i]["image"]).convert("RGB")
                    }
                elif isinstance(self.raw_data[i]["image"], Dict):
                    ### for multi-image input, the placeholder for every image is <image_xx>,
                    ### such as <image_00>, <image_01>
                    images_dict = {
                        img_name: Image.open(img_path).convert("RGB")
                        for img_name, img_path in self.raw_data[i]["image"].items()
                    }
            elif "video" in self.raw_data[i]:
                # video inputs: drop per-image ids and switch to the video slice config
                use_image_id = False
                slice_config = self.video_slice_config
                if isinstance(self.raw_data[i]["video"], str):
                    frames = encode_video(
                        self.raw_data[i]["video"], max_num_frames=self.max_num_frames
                    )
                    image_names = []
                    images_dict = {}
                    for j, frame in enumerate(frames):
                        image_name = "<image_{:02d}>".format(j)
                        images_dict[image_name] = frame
                        image_names.append(image_name)
                    # expand the single <video> placeholder into one placeholder per frame
                    for j in range(len(self.raw_data[i]["conversations"])):
                        content = self.raw_data[i]["conversations"][j]['content']
                        self.raw_data[i]["conversations"][j]['content'] = content.replace(
                            "<video>", "".join(image_names)
                        )
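
# ---------------------------------------------------------------------------
# Sketch of the raw_data schema this dataset expects, inferred from the
# branches in __getitem__ above (file paths and message text here are
# hypothetical examples, not part of this module):
#
#   single image: {"image": "images/cat.jpg",
#                  "conversations": [
#                      {"role": "user", "content": "<image>\nWhat is shown?"},
#                      {"role": "assistant", "content": "A cat on a sofa."}]}
#
#   multi image:  {"image": {"<image_00>": "a.jpg", "<image_01>": "b.jpg"},
#                  "conversations": [...]}   # dict keys are the in-text placeholders
#
#   video:        {"video": "clips/demo.mp4",
#                  "conversations": [...]}   # a "<video>" token in the text is
#                                            # expanded into one <image_xx> per
#                                            # frame returned by encode_video()
# ---------------------------------------------------------------------------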