mirror of https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-04 09:29:20 +08:00
fix: floor (#293)
@@ -1,16 +1,17 @@
-import os
 import math
+import os
 
 import librosa
 import numpy as np
 import torch
 
 from einops import rearrange
 from transformers import AutoFeatureExtractor
 
 
 class AudioProcessor:
     def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
         self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
 
     def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
         if not os.path.exists(wav_path):
             return None
@@ -19,11 +20,11 @@ class AudioProcessor:
         # Split audio into 30s segments
         segment_length = 30 * sampling_rate
         segments = [librosa_output[i:i + segment_length] for i in range(0, len(librosa_output), segment_length)]
 
         features = []
         for segment in segments:
             audio_feature = self.feature_extractor(
                 segment,
                 return_tensors="pt",
                 sampling_rate=sampling_rate
             ).input_features
@@ -32,13 +33,13 @@ class AudioProcessor:
             features.append(audio_feature)
 
         return features, len(librosa_output)
 
     def get_whisper_chunk(
         self,
         whisper_input_features,
         device,
         weight_dtype,
         whisper,
         librosa_length,
         fps=25,
         audio_padding_length_left=2,
@@ -48,30 +49,30 @@ class AudioProcessor:
         whisper_feature = []
         # Process multiple 30s mel input features
         for input_feature in whisper_input_features:
             audio_feats = whisper.encoder(input_feature.to(device), output_hidden_states=True).hidden_states
             audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype)
             whisper_feature.append(audio_feats)
 
         whisper_feature = torch.cat(whisper_feature, dim=1)
         # Trim the last segment to remove padding
         sr = 16000
         audio_fps = 50
         fps = int(fps)
         whisper_idx_multiplier = audio_fps / fps
-        num_frames = math.floor((librosa_length / sr)) * fps
-        actual_length = math.floor((librosa_length / sr)) * audio_fps
+        num_frames = math.floor((librosa_length / sr) * fps)
+        actual_length = math.floor((librosa_length / sr) * audio_fps)
         whisper_feature = whisper_feature[:,:actual_length,...]
 
         # Calculate padding amount
         padding_nums = math.floor(whisper_idx_multiplier)
         # Add padding at start and end
         whisper_feature = torch.cat([
             torch.zeros_like(whisper_feature[:, :padding_nums * audio_padding_length_left]),
             whisper_feature,
             # Add extra padding to prevent out of bounds
             torch.zeros_like(whisper_feature[:, :padding_nums * 3 * audio_padding_length_right])
         ], 1)
 
         audio_prompts = []
         for frame_index in range(num_frames):
             try:
@@ -86,7 +87,7 @@ class AudioProcessor:
                 print(f"num frames: {num_frames}, fps: {fps}, whisper_idx_multiplier: {whisper_idx_multiplier}")
                 print(f"frame_index: {frame_index}, audio_index: {audio_index}-{audio_index + audio_feature_length_per_frame}")
                 exit()
 
         audio_prompts = torch.cat(audio_prompts, dim=0) # T, 10, 5, 384
         audio_prompts = rearrange(audio_prompts, 'b c h w -> b (c h) w')
         return audio_prompts
@@ -97,5 +98,4 @@ if __name__ == "__main__":
     audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
     print("Audio Feature shape:", audio_feature.shape)
     print("librosa_feature_length:", librosa_feature_length)
 
-
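Why the parentheses matter: the old expressions floor the duration in seconds first and multiply afterwards, so any clip whose length is not a whole number of seconds loses up to fps - 1 video frames and audio_fps - 1 Whisper feature frames. The fix multiplies first and floors the product. A minimal sketch of the difference, using the constants from get_whisper_chunk (sr = 16000, audio_fps = 50, fps = 25) and an assumed 10.5-second clip:

import math

# Constants as defined in get_whisper_chunk
sr = 16000      # audio sampling rate (Hz)
audio_fps = 50  # Whisper encoder feature frames per second
fps = 25        # video frame rate

# Assumed example input: a 10.5-second clip
librosa_length = int(10.5 * sr)  # 168000 samples

# Before the fix: floor to whole seconds, then scale
old_num_frames = math.floor(librosa_length / sr) * fps           # floor(10.5) * 25 = 250
old_actual_length = math.floor(librosa_length / sr) * audio_fps  # floor(10.5) * 50 = 500

# After the fix: scale, then floor
new_num_frames = math.floor((librosa_length / sr) * fps)           # floor(262.5) = 262
new_actual_length = math.floor((librosa_length / sr) * audio_fps)  # floor(525.0) = 525

print(new_num_frames - old_num_frames)        # 12 video frames were being dropped
print(new_actual_length - old_actual_length)  # 25 feature frames were being trimmed away

Because whisper_feature is sliced to actual_length before padding, the old expression also discarded valid encoder features for the fractional final second, so the generated video could come out shorter than its audio.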