commit db204311a5
parent 058f7ddc7f
Author: aidenyzhang
Date: 2025-03-28 16:03:02 +08:00
46 changed files with 729 additions and 204 deletions

musetalk/utils/utils.py (42 changes, Normal file → Executable file)

@@ -15,13 +15,24 @@ from musetalk.whisper.audio2feature import Audio2Feature
 from musetalk.models.vae import VAE
 from musetalk.models.unet import UNet,PositionalEncoding
 
-def load_all_model():
-    audio_processor = Audio2Feature(model_path="./models/whisper/tiny.pt")
-    vae = VAE(model_path = "./models/sd-vae-ft-mse/")
-    unet = UNet(unet_config="./models/musetalk/musetalk.json",
-                model_path ="./models/musetalk/pytorch_model.bin")
+def load_all_model(
+    unet_model_path="./models/musetalk/pytorch_model.bin",
+    vae_type="sd-vae-ft-mse",
+    unet_config="./models/musetalk/musetalk.json",
+    device=None,
+):
+    vae = VAE(
+        model_path = f"./models/{vae_type}/",
+    )
+    print(f"load unet model from {unet_model_path}")
+    unet = UNet(
+        unet_config=unet_config,
+        model_path=unet_model_path,
+        device=device
+    )
     pe = PositionalEncoding(d_model=384)
-    return audio_processor,vae,unet,pe
+    return vae, unet, pe
 
 def get_file_type(video_path):
     _, ext = os.path.splitext(video_path)
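After this hunk, load_all_model no longer builds the Whisper Audio2Feature processor and instead takes the checkpoint paths and target device as keyword arguments, returning only (vae, unet, pe). A minimal call-site sketch under the new signature follows; constructing Audio2Feature separately at the caller is an assumption based on this hunk alone, and the default paths simply mirror the old hard-coded ones.

import torch

from musetalk.utils.utils import load_all_model
from musetalk.whisper.audio2feature import Audio2Feature

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# New-style loader call: paths and device are parameters instead of constants.
vae, unet, pe = load_all_model(
    unet_model_path="./models/musetalk/pytorch_model.bin",
    vae_type="sd-vae-ft-mse",
    unet_config="./models/musetalk/musetalk.json",
    device=device,
)

# load_all_model no longer returns the audio processor; building it at the
# call site is assumed here, not shown in this diff.
audio_processor = Audio2Feature(model_path="./models/whisper/tiny.pt")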
@@ -39,10 +50,13 @@ def get_video_fps(video_path):
     video.release()
     return fps
 
-def datagen(whisper_chunks,
-            vae_encode_latents,
-            batch_size=8,
-            delay_frame=0):
+def datagen(
+    whisper_chunks,
+    vae_encode_latents,
+    batch_size=8,
+    delay_frame=0,
+    device="cuda:0",
+):
     whisper_batch, latent_batch = [], []
     for i, w in enumerate(whisper_chunks):
         idx = (i+delay_frame)%len(vae_encode_latents)
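datagen keeps its old parameters but gains a device keyword defaulting to "cuda:0". A hedged call-site sketch follows; the dummy tensors stand in for the real Whisper features and VAE latents, and their shapes are illustrative only, not taken from this diff.

import torch

from musetalk.utils.utils import datagen

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Illustrative stand-ins for per-frame audio features and VAE latents.
whisper_chunks = [torch.randn(50, 384) for _ in range(10)]
vae_encode_latents = [torch.randn(1, 8, 32, 32) for _ in range(10)]

gen = datagen(
    whisper_chunks,
    vae_encode_latents,
    batch_size=4,
    delay_frame=0,
    device=device,
)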
@@ -51,14 +65,14 @@ def datagen(whisper_chunks,
         latent_batch.append(latent)
 
         if len(latent_batch) >= batch_size:
-            whisper_batch = np.stack(whisper_batch)
+            whisper_batch = torch.stack(whisper_batch)
             latent_batch = torch.cat(latent_batch, dim=0)
             yield whisper_batch, latent_batch
             whisper_batch, latent_batch = [], []
 
     # the last batch may smaller than batch size
     if len(latent_batch) > 0:
-        whisper_batch = np.stack(whisper_batch)
+        whisper_batch = torch.stack(whisper_batch)
         latent_batch = torch.cat(latent_batch, dim=0)
-        yield whisper_batch, latent_batch
+        yield whisper_batch.to(device), latent_batch.to(device)
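With np.stack replaced by torch.stack, the whisper chunks are now expected to arrive as torch tensors rather than numpy arrays, and in this revision only the final, possibly smaller batch is moved to device inside the generator. Continuing the sketch above, a consumer might therefore still move each yielded batch explicitly; this is a call-site precaution assumed here, not something the diff requires.

for whisper_batch, latent_batch in gen:
    # The in-loop yield does not call .to(device) in this revision, so move
    # both tensors defensively; .to() is a no-op if they are already there.
    whisper_batch = whisper_batch.to(device)
    latent_batch = latent_batch.to(device)
    print(whisper_batch.shape, latent_batch.shape)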