<enhance>: support using float16 in inference to speed it up

This commit is contained in:
czk32611
2024-04-27 14:26:50 +08:00
parent 2c52de01b4
commit 865a68c60e
6 changed files with 103 additions and 51 deletions

View File

@@ -37,11 +37,11 @@ class UNet():
self.model = UNet2DConditionModel(**unet_config)
self.pe = PositionalEncoding(d_model=384)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
self.model.load_state_dict(self.weights)
weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
self.model.load_state_dict(weights)
if use_float16:
self.model = self.model.half()
self.model.to(self.device)
if __name__ == "__main__":
unet = UNet()
unet = UNet()

View File

@@ -39,7 +39,10 @@ def get_video_fps(video_path):
video.release()
return fps
def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0):
def datagen(whisper_chunks,
vae_encode_latents,
batch_size=8,
delay_frame=0):
whisper_batch, latent_batch = [], []
for i, w in enumerate(whisper_chunks):
idx = (i+delay_frame)%len(vae_encode_latents)
@@ -48,14 +51,14 @@ def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0):
latent_batch.append(latent)
if len(latent_batch) >= batch_size:
whisper_batch = np.asarray(whisper_batch)
whisper_batch = np.stack(whisper_batch)
latent_batch = torch.cat(latent_batch, dim=0)
yield whisper_batch, latent_batch
whisper_batch, latent_batch = [], []
            # the last batch may be smaller than the batch size
if len(latent_batch) > 0:
whisper_batch = np.asarray(whisper_batch)
whisper_batch = np.stack(whisper_batch)
latent_batch = torch.cat(latent_batch, dim=0)
yield whisper_batch, latent_batch
yield whisper_batch, latent_batch

View File

@@ -13,7 +13,11 @@ class Audio2Feature():
self.whisper_model_type = whisper_model_type
self.model = load_model(model_path) #
def get_sliced_feature(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25):
def get_sliced_feature(self,
feature_array,
vid_idx,
audio_feat_length=[2,2],
fps=25):
"""
Get sliced features based on a given index
:param feature_array: