import os
import numpy as np
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, ConcatDataset
import torchvision.transforms as transforms
from transformers import AutoFeatureExtractor
import librosa
import time
import json
import math
from decord import AudioReader, VideoReader
from decord.ndarray import cpu
from musetalk.data.sample_method import get_src_idx, shift_landmarks_to_face_coordinates, resize_landmark
from musetalk.data import audio
from musetalk.utils.audio_utils import ensure_wav

# 16-frame SyncNet window at 16 mel steps per 5 video frames, rounded up:
# ceil(16 / 5 * 16) = 52 mel frames.
syncnet_mel_step_size = math.ceil(16 / 5 * 16)  # latentsync


class FaceDataset(Dataset):
    """Dataset class for loading and processing video data.

    Each video can be represented as:
    - concatenated frame images,
    - an '.mp4' or '.gif' file, or
    - a folder containing all frames.
    """

    def __init__(self, cfg, list_paths, root_path='./dataset/', repeats=None):
        # Initialize dataset paths
        meta_paths = []
        if repeats is None:
            repeats = [1] * len(list_paths)
        assert len(repeats) == len(list_paths)

        # Load data lists; each entry may be repeated to over-sample a dataset
        for list_path, repeat_time in zip(list_paths, repeats):
            with open(list_path, 'r') as f:
                num = 0
                f.readline()  # Skip header line
                for line in f.readlines():
                    line_info = line.strip()
                    meta = line_info.split()
                    meta = meta[0]
                    meta_paths.extend([os.path.join(root_path, meta)] * repeat_time)
                    num += 1
                print(f'{list_path}: {num} x {repeat_time} = {num * repeat_time} samples')

        # Set basic attributes
        self.meta_paths = meta_paths
        self.root_path = root_path
        self.image_size = cfg['image_size']
        self.min_face_size = cfg['min_face_size']
        self.T = cfg['T']
        self.sample_method = cfg['sample_method']
        self.top_k_ratio = cfg['top_k_ratio']
        self.max_attempts = 200
        self.padding_pixel_mouth = cfg['padding_pixel_mouth']

        # Cropping-related parameters
        self.crop_type = cfg['crop_type']
        self.jaw2edge_margin_mean = cfg['cropping_jaw2edge_margin_mean']
        self.jaw2edge_margin_std = cfg['cropping_jaw2edge_margin_std']
        self.random_margin_method = cfg['random_margin_method']

        # Image transformations
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.pose_to_tensor = transforms.Compose([
            transforms.ToTensor(),
        ])

        # Audio feature extractor (Whisper)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(cfg['whisper_path'])
        self.contorl_face_min_size = cfg["contorl_face_min_size"]
        print("The sample method is:", self.sample_method)
        print(f"Only use faces larger than {self.min_face_size}px:", self.contorl_face_min_size)

    def generate_random_value(self):
        """Sample a random extra margin (in pixels).

        Returns:
            float: Non-negative margin drawn with the configured method.
        """
        if self.random_margin_method == "uniform":
            random_value = np.random.uniform(
                self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
                self.jaw2edge_margin_mean + self.jaw2edge_margin_std
            )
        elif self.random_margin_method == "normal":
            random_value = np.random.normal(
                loc=self.jaw2edge_margin_mean,
                scale=self.jaw2edge_margin_std
            )
            random_value = np.clip(
                random_value,
                self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
                self.jaw2edge_margin_mean + self.jaw2edge_margin_std,
            )
        else:
            raise ValueError(f"Invalid random margin method: {self.random_margin_method}")
        return max(0, random_value)
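    # Note on margin sampling: "uniform" draws directly from [mean - std, mean + std],
    # while "normal" draws from N(mean, std) and is then clipped to the same interval;
    # max(0, ...) keeps the margin non-negative. The sampled margin is consumed by
    # dynamic_margin_crop below, which only extends the bottom edge (y2) of the face
    # bbox, and it is later reused to shift the mouth mask downward
    # (see crop_margin in get_resized_mouth_mask and __getitem__).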
    def dynamic_margin_crop(self, img, original_bbox, extra_margin=None):
        """Extend the bounding box downward by an extra (possibly random) margin.

        Args:
            img: Input image
            original_bbox: Original bounding box (x1, y1, x2, y2)
            extra_margin: Extra margin in pixels; sampled if None

        Returns:
            tuple: (x1, y1, x2, y2, extra_margin)
        """
        if extra_margin is None:
            extra_margin = self.generate_random_value()
        w, h = img.size
        x1, y1, x2, y2 = original_bbox
        y2 = min(y2 + int(extra_margin), h)  # only the bottom edge is extended
        return x1, y1, x2, y2, extra_margin

    def crop_resize_img(self, img, bbox, crop_type='crop_resize', extra_margin=None):
        """Crop and resize an image.

        Args:
            img: Input image
            bbox: Bounding box (x1, y1, x2, y2)
            crop_type: Type of cropping
            extra_margin: Extra margin for 'dynamic_margin_crop_resize'

        Returns:
            tuple: (processed image, extra_margin, mask_scaled_factor)
        """
        mask_scaled_factor = 1.
        if crop_type == 'crop_resize':
            x1, y1, x2, y2 = bbox
            img = img.crop((x1, y1, x2, y2))
            img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
        elif crop_type == 'dynamic_margin_crop_resize':
            x1, y1, x2, y2, extra_margin = self.dynamic_margin_crop(img, bbox, extra_margin)
            w_original, _ = img.size
            img = img.crop((x1, y1, x2, y2))
            w_cropped, _ = img.size
            mask_scaled_factor = w_cropped / w_original
            img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
        elif crop_type == 'resize':
            w, h = img.size
            scale = np.sqrt(self.image_size ** 2 / (h * w))
            new_w = int(w * scale) // 64 * 64  # round down to a multiple of 64
            new_h = int(h * scale) // 64 * 64
            img = img.resize((new_w, new_h), Image.LANCZOS)
        return img, extra_margin, mask_scaled_factor

    def get_audio_file(self, wav_path, start_index):
        """Get Whisper input features for the audio file.

        Args:
            wav_path: Audio file path
            start_index: Starting video frame index (25 fps)

        Returns:
            tuple: (audio features, start index re-based into the selected 30 s window)
        """
        if not os.path.exists(wav_path):
            return None
        wav_path_converted = ensure_wav(wav_path)
        audio_input_librosa, sampling_rate = librosa.load(wav_path_converted, sr=16000)
        assert sampling_rate == 16000

        # Whisper consumes 30 s windows (25 * 30 = 750 frames at 25 fps); move
        # start_index into the current window.
        while start_index >= 25 * 30:
            audio_input = audio_input_librosa[16000 * 30:]
            start_index -= 25 * 30
        if start_index + 2 * 25 >= 25 * 30:
            # Too close to the end of the window: shift the window 4 s later
            start_index -= 4 * 25
            audio_input = audio_input_librosa[16000 * 4:16000 * 34]
        else:
            audio_input = audio_input_librosa[:16000 * 30]
        # Whisper yields 2 feature frames per 25 fps video frame (1500 frames per 30 s)
        assert 2 * start_index >= 0
        assert 2 * (start_index + 2 * 25) <= 1500

        audio_input = self.feature_extractor(
            audio_input,
            return_tensors="pt",
            sampling_rate=sampling_rate
        ).input_features
        return audio_input, start_index

    def get_audio_file_mel(self, wav_path, start_index):
        """Get the mel spectrogram of the audio file.

        Args:
            wav_path: Audio file path
            start_index: Starting index (returned unchanged)

        Returns:
            tuple: (mel spectrogram, start index)
        """
        if not os.path.exists(wav_path):
            return None
        wav_path_converted = ensure_wav(wav_path)
        audio_input_librosa, sampling_rate = librosa.load(wav_path_converted, sr=16000)
        assert sampling_rate == 16000
        audio_mel = self.mel_feature_extractor(audio_input_librosa)
        return audio_mel, start_index

    def mel_feature_extractor(self, audio_input):
        """Extract mel spectrogram features.

        Args:
            audio_input: Input audio waveform

        Returns:
            ndarray: Mel spectrogram, shape (time, n_mels)
        """
        orig_mel = audio.melspectrogram(audio_input)
        return orig_mel.T

    def crop_audio_window(self, spec, start_frame_num, fps=25):
        """Crop the mel window aligned with the video clip.

        Args:
            spec: Mel spectrogram, shape (time, n_mels)
            start_frame_num: Starting video frame number
            fps: Video frames per second

        Returns:
            ndarray: Cropped spectrogram of syncnet_mel_step_size frames
        """
        start_idx = int(80. * (start_frame_num / float(fps)))
        end_idx = start_idx + syncnet_mel_step_size
        return spec[start_idx:end_idx, :]
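    # Timing note: the "80." factor in crop_audio_window assumes audio.melspectrogram
    # produces 80 mel frames per second of audio, so a start frame at 25 fps maps to
    # mel index 80 * frame / 25, and syncnet_mel_step_size (= 52) mel frames cover the
    # 16-frame SyncNet clip with the ceil rounding defined at the top of the file.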
    def get_syncnet_input(self, video_path):
        """Get SyncNet input features from the video's audio track.

        Args:
            video_path: Video file path

        Returns:
            ndarray: Mel spectrogram, shape (time, n_mels)
        """
        ar = AudioReader(video_path, sample_rate=16000)
        original_mel = audio.melspectrogram(ar[:].asnumpy().squeeze(0))
        return original_mel.T

    def get_resized_mouth_mask(self, img_resized, landmark_array, face_shape,
                               padding_pixel_mouth=0, image_size=256, crop_margin=0):
        """Build a rectangular mouth mask in the resized face crop."""
        landmark_array = np.array(landmark_array)
        resized_landmark = resize_landmark(
            landmark_array, w=face_shape[0], h=face_shape[1],
            new_w=image_size, new_h=image_size)
        landmark_array = np.array(resized_landmark[48:67])  # mouth landmarks in the 68-point layout
        min_x, min_y = np.min(landmark_array, axis=0)
        max_x, max_y = np.max(landmark_array, axis=0)
        min_x = min_x - padding_pixel_mouth
        max_x = max_x + padding_pixel_mouth

        # Calculate x-axis length and use it for the y-axis
        width = max_x - min_x
        # Calculate old center point
        center_y = (max_y + min_y) / 2
        # Determine new min_y and max_y based on width
        min_y = center_y - width / 4
        max_y = center_y + width / 4

        # Adjust mask position for dynamic crop: shift along the y-axis
        min_y = min_y - crop_margin
        max_y = max_y - crop_margin

        # Prevent out of bounds
        min_x = max(min_x, 0)
        min_y = max(min_y, 0)
        max_x = min(max_x, face_shape[0])
        max_y = min(max_y, face_shape[1])

        mask = np.zeros_like(np.array(img_resized))
        mask[round(min_y):round(max_y), round(min_x):round(max_x)] = 255
        return Image.fromarray(mask)

    def __len__(self):
        return 100000

    def __getitem__(self, idx):
        # idx is ignored: a meta file is sampled at random, with retries on bad data
        attempts = 0
        while attempts < self.max_attempts:
            try:
                meta_path = random.sample(self.meta_paths, k=1)[0]
                with open(meta_path, 'r') as f:
                    meta_data = json.load(f)
            except Exception as e:
                print(f"meta file error: {meta_path}")
                print(e)
                attempts += 1
                time.sleep(0.1)
                continue

            video_path = meta_data["mp4_path"]
            wav_path = meta_data["wav_path"]
            bbox_list = meta_data["face_list"]
            landmark_list = meta_data["landmark_list"]

            T = self.T
            s = 0
            e = meta_data["frames"]
            len_valid_clip = e - s
            if len_valid_clip < T * 10:
                attempts += 1
                print(f"video {video_path} has less than {T * 10} frames")
                continue

            try:
                cap = VideoReader(video_path, fault_tol=1, ctx=cpu(0))
                total_frames = len(cap)
                assert total_frames == len(landmark_list)
                assert total_frames == len(bbox_list)
                landmark_shape = np.array(landmark_list).shape
                if landmark_shape != (total_frames, 68, 2):  # we use 68 landmarks
                    attempts += 1
                    print(f"video {video_path} has invalid landmark shape: {landmark_shape}, "
                          f"expected: {(total_frames, 68, 2)}")
                    continue
            except Exception as e:
                print(f"video file error: {video_path}")
                print(e)
                attempts += 1
                time.sleep(0.1)
                continue

            shift_landmarks, bbox_list_union, face_shapes = shift_landmarks_to_face_coordinates(
                landmark_list, bbox_list)
            if self.contorl_face_min_size and face_shapes[0][0] < self.min_face_size:
                print(f"video {video_path} has face size {face_shapes[0][0]} less than minimum required {self.min_face_size}")
                attempts += 1
                continue

            # Pick T consecutive driving frames and one reference frame per driving frame
            step = 1
            drive_idx_start = random.randint(s, e - T * step)
            drive_idx_list = list(range(drive_idx_start, drive_idx_start + T * step, step))
            assert len(drive_idx_list) == T

            src_idx_list = []
            list_index_out_of_range = False
            for drive_idx in drive_idx_list:
                src_idx = get_src_idx(
                    drive_idx, T, self.sample_method, shift_landmarks, face_shapes, self.top_k_ratio)
                if src_idx is None:
                    list_index_out_of_range = True
                    break
                src_idx = min(src_idx, e - 1)
                src_idx = max(src_idx, s)
                src_idx_list.append(src_idx)
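            # src_idx_list now holds one reference-frame index per driving frame,
            # chosen with self.sample_method and clamped to the valid range [s, e - 1].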
            if list_index_out_of_range:
                attempts += 1
                print(f"video {video_path} has invalid source index for drive frames")
                continue

            ref_face_valid_flag = True
            extra_margin = self.generate_random_value()

            # Get reference images
            ref_imgs = []
            for src_idx in src_idx_list:
                imSrc = Image.fromarray(cap[src_idx].asnumpy())
                bbox_s = bbox_list_union[src_idx]
                imSrc, _, _ = self.crop_resize_img(
                    imSrc, bbox_s, self.crop_type, extra_margin=None)
                if self.contorl_face_min_size and min(imSrc.size[0], imSrc.size[1]) < self.min_face_size:
                    ref_face_valid_flag = False
                    break
                ref_imgs.append(imSrc)

            if not ref_face_valid_flag:
                attempts += 1
                print(f"video {video_path} has reference face size smaller than minimum required {self.min_face_size}")
                continue

            # Get target images and masks
            imSameIDs = []
            bboxes = []
            face_masks = []
            face_mask_valid = True
            target_face_valid_flag = True
            for drive_idx in drive_idx_list:
                imSameID = Image.fromarray(cap[drive_idx].asnumpy())
                bbox_s = bbox_list_union[drive_idx]
                imSameID, _, mask_scaled_factor = self.crop_resize_img(
                    imSameID, bbox_s, self.crop_type, extra_margin=extra_margin)
                if self.contorl_face_min_size and min(imSameID.size[0], imSameID.size[1]) < self.min_face_size:
                    target_face_valid_flag = False
                    break

                crop_margin = extra_margin * mask_scaled_factor
                face_mask = self.get_resized_mouth_mask(
                    imSameID, shift_landmarks[drive_idx], face_shapes[drive_idx],
                    self.padding_pixel_mouth, self.image_size, crop_margin=crop_margin)

                if np.count_nonzero(face_mask) == 0:
                    face_mask_valid = False
                    break
                if face_mask.size[1] == 0 or face_mask.size[0] == 0:
                    print(f"video {video_path} has invalid face mask size at frame {drive_idx}")
                    face_mask_valid = False
                    break

                imSameIDs.append(imSameID)
                bboxes.append(bbox_s)
                face_masks.append(face_mask)

            if not face_mask_valid:
                attempts += 1
                print(f"video {video_path} has invalid face mask")
                continue
            if not target_face_valid_flag:
                attempts += 1
                print(f"video {video_path} has target face size smaller than minimum required {self.min_face_size}")
                continue

            # Process audio features
            audio_offset = drive_idx_list[0]
            audio_step = step
            fps = 25.0 / step
            try:
                audio_feature, audio_offset = self.get_audio_file(wav_path, audio_offset)
                _, audio_offset = self.get_audio_file_mel(wav_path, audio_offset)
                audio_feature_mel = self.get_syncnet_input(video_path)
            except Exception as e:
                print(f"audio file error:{wav_path}")
                print(e)
                attempts += 1
                time.sleep(0.1)
                continue

            mel = self.crop_audio_window(audio_feature_mel, audio_offset)
            if mel.shape[0] != syncnet_mel_step_size:
                attempts += 1
                print(f"video {video_path} has invalid mel spectrogram shape: {mel.shape}, expected: {syncnet_mel_step_size}")
                continue
            mel = torch.FloatTensor(mel.T).unsqueeze(0)

            # Build sample dictionary
            sample = dict(
                pixel_values_vid=torch.stack(
                    [self.to_tensor(imSameID) for imSameID in imSameIDs], dim=0),
                pixel_values_ref_img=torch.stack(
                    [self.to_tensor(ref_img) for ref_img in ref_imgs], dim=0),
                pixel_values_face_mask=torch.stack(
                    [self.pose_to_tensor(face_mask) for face_mask in face_masks], dim=0),
                audio_feature=audio_feature[0],
                audio_offset=audio_offset,
                audio_step=audio_step,
                mel=mel,
                wav_path=wav_path,
                fps=fps,
            )
            return sample

        raise ValueError("Unable to find a valid sample after maximum attempts.")


class HDTFDataset(FaceDataset):
    """HDTF dataset class"""

    def __init__(self, cfg):
        root_path = './dataset/HDTF/meta'
        list_paths = [
            './dataset/HDTF/train.txt',
        ]
        repeats = [10]
        super().__init__(cfg, list_paths, root_path, repeats)
        print('HDTFDataset: ', len(self))
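
# Other datasets can be plugged in the same way as HDTFDataset above and VFHQDataset
# below: subclass FaceDataset and point it at a meta root plus one or more train list
# files. A hypothetical sketch (the paths and repeat count are placeholders, not part
# of this repository):
#
#   class MyTalkingHeadDataset(FaceDataset):
#       def __init__(self, cfg):
#           super().__init__(cfg,
#                            list_paths=['./dataset/MyData/train.txt'],
#                            root_path='./dataset/MyData/meta',
#                            repeats=[1])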

class VFHQDataset(FaceDataset):
    """VFHQ dataset class"""

    def __init__(self, cfg):
        root_path = './dataset/VFHQ/meta'
        list_paths = [
            './dataset/VFHQ/train.txt',
        ]
        repeats = [1]
        super().__init__(cfg, list_paths, root_path, repeats)
        print('VFHQDataset: ', len(self))


def PortraitDataset(cfg=None):
    """Return the dataset selected by the configuration.

    Args:
        cfg: Configuration dictionary

    Returns:
        Dataset: Combined dataset
    """
    if cfg["dataset_key"] == "HDTF":
        return ConcatDataset([HDTFDataset(cfg)])
    elif cfg["dataset_key"] == "VFHQ":
        return ConcatDataset([VFHQDataset(cfg)])
    else:
        print("############ use all datasets ############")
        return ConcatDataset([HDTFDataset(cfg), VFHQDataset(cfg)])


if __name__ == '__main__':
    # Set random seeds for reproducibility
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Create dataset with configuration parameters
    dataset = PortraitDataset(cfg={
        'T': 1,  # Number of frames to process at once
        'random_margin_method': "normal",  # Method for generating random margins: "normal" or "uniform"
        'dataset_key': "HDTF",  # Dataset to use: "HDTF", "VFHQ", or any other value for both
        'image_size': 256,  # Size of processed images (height and width)
        'sample_method': 'pose_similarity_and_mouth_dissimilarity',  # Method for selecting reference frames
        'top_k_ratio': 0.51,  # Ratio for top-k selection in reference frame sampling
        'contorl_face_min_size': True,  # Whether to enforce minimum face size
        'padding_pixel_mouth': 10,  # Padding pixels around mouth region in mask
        'min_face_size': 200,  # Minimum face size requirement for dataset
        'whisper_path': "./models/whisper",  # Path to Whisper model
        'cropping_jaw2edge_margin_mean': 10,  # Mean margin for jaw-to-edge cropping
        'cropping_jaw2edge_margin_std': 10,  # Standard deviation for jaw-to-edge cropping
        'crop_type': "dynamic_margin_crop_resize",  # "crop_resize", "dynamic_margin_crop_resize", or "resize"
    })
    print(len(dataset))

    import torchvision
    os.makedirs('debug', exist_ok=True)

    for i in range(10):  # Check 10 samples (the index is ignored; sampling is random)
        sample = dataset[0]
        print(f"processing {i}")

        # Get images and mask
        ref_img = (sample['pixel_values_ref_img'] + 1.0) / 2  # (T, c, h, w)
        target_img = (sample['pixel_values_vid'] + 1.0) / 2
        face_mask = sample['pixel_values_face_mask']

        # Print dimension information
        print(f"ref_img shape: {ref_img.shape}")
        print(f"target_img shape: {target_img.shape}")
        print(f"face_mask shape: {face_mask.shape}")

        # Create visualization images
        b, c, h, w = ref_img.shape

        # Apply the mask only to the target image
        target_mask = face_mask

        # Keep reference image unchanged
        ref_with_mask = ref_img.clone()

        # Create mask overlay for target image
        target_with_mask = target_img.clone()
        target_with_mask = target_with_mask * (1 - target_mask) + target_mask  # Apply mask only to target

        # Save originals, masks, and overlays side by side.
        # Panels left to right: ref / target originals, blank / target mask,
        # ref / target with mask overlay.
        concatenated_img = torch.cat((
            ref_img, target_img,                     # Original images
            torch.zeros_like(ref_img), target_mask,  # Mask (black for ref)
            ref_with_mask, target_with_mask          # Overlay effect
        ), dim=3)

        torchvision.utils.save_image(
            concatenated_img, f'debug/mask_check_{i}.jpg', nrow=2)
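

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original training pipeline): the
# combined dataset would typically be wrapped in a torch DataLoader like this.
# batch_size and num_workers are placeholder values, not repository defaults.
def build_example_dataloader(cfg, batch_size=2, num_workers=4):
    """Return a DataLoader over PortraitDataset; all arguments are illustrative."""
    from torch.utils.data import DataLoader

    return DataLoader(
        PortraitDataset(cfg),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )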