fix: floor (#293)

Chenghao Mou
2025-04-04 06:04:56 +01:00
committed by GitHub
parent 23d88dcfb9
commit e636166b85


@@ -1,16 +1,17 @@
-import os
 import math
+import os
+
 import librosa
 import numpy as np
 import torch
 from einops import rearrange
 from transformers import AutoFeatureExtractor


 class AudioProcessor:
     def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
         self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)

     def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
         if not os.path.exists(wav_path):
             return None
@@ -19,11 +20,11 @@ class AudioProcessor:
         # Split audio into 30s segments
         segment_length = 30 * sampling_rate
         segments = [librosa_output[i:i + segment_length] for i in range(0, len(librosa_output), segment_length)]

         features = []
         for segment in segments:
             audio_feature = self.feature_extractor(
                 segment,
                 return_tensors="pt",
                 sampling_rate=sampling_rate
             ).input_features
@@ -32,13 +33,13 @@ class AudioProcessor:
             features.append(audio_feature)

         return features, len(librosa_output)

     def get_whisper_chunk(
         self,
         whisper_input_features,
         device,
         weight_dtype,
         whisper,
         librosa_length,
         fps=25,
         audio_padding_length_left=2,
@@ -48,30 +49,30 @@ class AudioProcessor:
         whisper_feature = []
         # Process multiple 30s mel input features
         for input_feature in whisper_input_features:
             audio_feats = whisper.encoder(input_feature.to(device), output_hidden_states=True).hidden_states
             audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype)
             whisper_feature.append(audio_feats)
         whisper_feature = torch.cat(whisper_feature, dim=1)

         # Trim the last segment to remove padding
         sr = 16000
         audio_fps = 50
         fps = int(fps)
         whisper_idx_multiplier = audio_fps / fps
-        num_frames = math.floor((librosa_length / sr)) * fps
-        actual_length = math.floor((librosa_length / sr)) * audio_fps
+        num_frames = math.floor((librosa_length / sr) * fps)
+        actual_length = math.floor((librosa_length / sr) * audio_fps)
         whisper_feature = whisper_feature[:,:actual_length,...]

         # Calculate padding amount
         padding_nums = math.floor(whisper_idx_multiplier)
         # Add padding at start and end
         whisper_feature = torch.cat([
             torch.zeros_like(whisper_feature[:, :padding_nums * audio_padding_length_left]),
             whisper_feature,
             # Add extra padding to prevent out of bounds
             torch.zeros_like(whisper_feature[:, :padding_nums * 3 * audio_padding_length_right])
         ], 1)

         audio_prompts = []
         for frame_index in range(num_frames):
             try:
@@ -86,7 +87,7 @@ class AudioProcessor:
                 print(f"num frames: {num_frames}, fps: {fps}, whisper_idx_multiplier: {whisper_idx_multiplier}")
                 print(f"frame_index: {frame_index}, audio_index: {audio_index}-{audio_index + audio_feature_length_per_frame}")
                 exit()

         audio_prompts = torch.cat(audio_prompts, dim=0)  # T, 10, 5, 384
         audio_prompts = rearrange(audio_prompts, 'b c h w -> b (c h) w')
         return audio_prompts
@@ -97,5 +98,4 @@ if __name__ == "__main__":
     audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
     print("Audio Feature shape:", audio_feature.shape)
     print("librosa_feature_length:", librosa_feature_length)
 
-