<enhance>: support using float16 in inference to speed up

czk32611
2024-04-27 14:26:50 +08:00
parent 2c52de01b4
commit 865a68c60e
6 changed files with 103 additions and 51 deletions
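The hunks extracted below cover only the datagen batching refactor; the float16 changes named in the commit title live in the other changed files, which are not shown here. As a rough sketch of the usual PyTorch pattern behind that title (the function and names below are illustrative assumptions, not code from this commit):

import torch

def use_half_precision(model, device="cuda"):
    # Illustrative only; not code from this commit. Casting weights to
    # torch.float16 halves memory traffic and enables faster GPU kernels,
    # which is where the advertised inference speed-up comes from.
    model = model.to(device)
    if device != "cpu":        # fp16 is only a win on GPU hardware
        model = model.half()   # convert all parameters to torch.float16
    return model.eval()

Inputs must then be cast to the same dtype before each forward pass; the datagen refactor below is the point where those inputs are assembled into batches, just before that cast.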


@@ -39,7 +39,10 @@ def get_video_fps(video_path):
     video.release()
     return fps
 
-def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0):
+def datagen(whisper_chunks,
+            vae_encode_latents,
+            batch_size=8,
+            delay_frame=0):
     whisper_batch, latent_batch = [], []
     for i, w in enumerate(whisper_chunks):
         idx = (i+delay_frame)%len(vae_encode_latents)
@@ -48,14 +51,14 @@ def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0):
         latent_batch.append(latent)
 
         if len(latent_batch) >= batch_size:
-            whisper_batch = np.asarray(whisper_batch)
+            whisper_batch = np.stack(whisper_batch)
             latent_batch = torch.cat(latent_batch, dim=0)
             yield whisper_batch, latent_batch
             whisper_batch, latent_batch = [], []
 
     # the last batch may smaller than batch size
     if len(latent_batch) > 0:
-        whisper_batch = np.asarray(whisper_batch)
+        whisper_batch = np.stack(whisper_batch)
         latent_batch = torch.cat(latent_batch, dim=0)
-        yield whisper_batch, latent_batch
+        yield whisper_batch, latent_batch
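Two notes on the hunk above. First, np.stack always adds a new leading batch axis and raises on ragged inputs, while np.asarray only happens to behave that way when every chunk already has an identical shape, so the replacement makes the batching intent explicit. Second, the final -/+ pair on the trailing yield appears identical because the change there is whitespace-only in the underlying file. A minimal consumer sketch follows; datagen and its two data arguments come from the diff, while the device, dtype, and model step are assumptions:

import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weight_dtype = torch.float16   # assumption: weights were cast via .half()

for whisper_batch, latent_batch in datagen(whisper_chunks,
                                           vae_encode_latents,
                                           batch_size=8):
    # np.stack guarantees whisper_batch is a dense (B, ...) array, so the
    # tensor round-trip and the fp16 cast below are cheap and well-defined.
    audio_features = torch.from_numpy(whisper_batch).to(device, weight_dtype)
    latents = latent_batch.to(device, weight_dtype)
    # ... feed audio_features and latents to the half-precision model ...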