mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-05 01:49:20 +08:00
<enhance>: support using float16 in inference to speed up
This commit is contained in:
@@ -39,7 +39,10 @@ def get_video_fps(video_path):
|
||||
video.release()
|
||||
return fps
|
||||
|
||||
def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0):
|
||||
def datagen(whisper_chunks,
|
||||
vae_encode_latents,
|
||||
batch_size=8,
|
||||
delay_frame=0):
|
||||
whisper_batch, latent_batch = [], []
|
||||
for i, w in enumerate(whisper_chunks):
|
||||
idx = (i+delay_frame)%len(vae_encode_latents)
|
||||
@@ -48,14 +51,14 @@ def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0):
|
||||
latent_batch.append(latent)
|
||||
|
||||
if len(latent_batch) >= batch_size:
|
||||
whisper_batch = np.asarray(whisper_batch)
|
||||
whisper_batch = np.stack(whisper_batch)
|
||||
latent_batch = torch.cat(latent_batch, dim=0)
|
||||
yield whisper_batch, latent_batch
|
||||
whisper_batch, latent_batch = [], []
|
||||
|
||||
# the last batch may smaller than batch size
|
||||
if len(latent_batch) > 0:
|
||||
whisper_batch = np.asarray(whisper_batch)
|
||||
whisper_batch = np.stack(whisper_batch)
|
||||
latent_batch = torch.cat(latent_batch, dim=0)
|
||||
|
||||
yield whisper_batch, latent_batch
|
||||
yield whisper_batch, latent_batch
|
||||
|
||||
Reference in New Issue
Block a user