mirror of https://github.com/TMElyralab/MuseTalk.git, synced 2026-02-04 09:29:20 +08:00
<enhance>: support using float16 in inference to speed up
@@ -37,11 +37,11 @@ class UNet():
         self.model = UNet2DConditionModel(**unet_config)
         self.pe = PositionalEncoding(d_model=384)
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
-        self.model.load_state_dict(self.weights)
+        weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
+        self.model.load_state_dict(weights)
+        if use_float16:
+            self.model = self.model.half()
         self.model.to(self.device)

 if __name__ == "__main__":
     unet = UNet()
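For orientation, here is a minimal sketch of the float16 inference pattern this hunk enables. The model, flag name, and tensor shapes below are placeholders, not taken from the repository; the point is only that once the parameters are cast with .half(), the inputs must be cast to the same dtype before the forward pass.

    import torch
    import torch.nn as nn

    use_float16 = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = nn.Linear(384, 384)   # placeholder standing in for the real UNet
    if use_float16:
        model = model.half()      # cast parameters to fp16, as in the hunk above
    model.to(device)

    x = torch.randn(8, 384, device=device)
    if use_float16:
        x = x.half()              # inputs must match the parameter dtype

    with torch.no_grad():
        y = model(x)              # fp16 forward pass
    print(y.dtype)                # torch.float16

A side note on the hunk itself: binding the checkpoint to a local `weights` instead of `self.weights` presumably also lets the full-precision state dict be freed once `load_state_dict` returns, rather than living on the instance.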
@@ -39,7 +39,10 @@ def get_video_fps(video_path):
     video.release()
     return fps

-def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0):
+def datagen(whisper_chunks,
+            vae_encode_latents,
+            batch_size=8,
+            delay_frame=0):
     whisper_batch, latent_batch = [], []
     for i, w in enumerate(whisper_chunks):
         idx = (i+delay_frame)%len(vae_encode_latents)
@@ -48,14 +51,14 @@ def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0):
         latent_batch.append(latent)

         if len(latent_batch) >= batch_size:
-            whisper_batch = np.asarray(whisper_batch)
+            whisper_batch = np.stack(whisper_batch)
             latent_batch = torch.cat(latent_batch, dim=0)
             yield whisper_batch, latent_batch
             whisper_batch, latent_batch = [], []

     # the last batch may be smaller than batch size
     if len(latent_batch) > 0:
-        whisper_batch = np.asarray(whisper_batch)
+        whisper_batch = np.stack(whisper_batch)
         latent_batch = torch.cat(latent_batch, dim=0)
         yield whisper_batch, latent_batch
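Read together with the previous hunk, the changed `datagen` amounts to the generator below. This is a reconstruction: the lines elided between the two hunks (appending the chunk and looking up the latent) are filled in from context, and the example shapes at the bottom are purely illustrative. One plausible motivation for `np.stack` over `np.asarray` is strictness, since `np.stack` raises if the chunk shapes disagree instead of silently producing an object array.

    import numpy as np
    import torch

    def datagen(whisper_chunks, vae_encode_latents, batch_size=8, delay_frame=0):
        # Pair each audio chunk with a (possibly delayed) frame latent,
        # cycling over the latents, and yield fixed-size batches.
        whisper_batch, latent_batch = [], []
        for i, w in enumerate(whisper_chunks):
            idx = (i + delay_frame) % len(vae_encode_latents)
            whisper_batch.append(w)
            latent_batch.append(vae_encode_latents[idx])

            if len(latent_batch) >= batch_size:
                yield np.stack(whisper_batch), torch.cat(latent_batch, dim=0)
                whisper_batch, latent_batch = [], []

        # the last batch may be smaller than batch_size
        if len(latent_batch) > 0:
            yield np.stack(whisper_batch), torch.cat(latent_batch, dim=0)

    chunks = [np.zeros((50, 384), dtype=np.float32) for _ in range(10)]
    latents = [torch.zeros(1, 8, 32, 32) for _ in range(4)]
    for wb, lb in datagen(chunks, latents, batch_size=4):
        print(wb.shape, lb.shape)  # (4, 50, 384) and [4, 8, 32, 32]; the final batch holds 2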
@@ -13,7 +13,11 @@ class Audio2Feature():
         self.whisper_model_type = whisper_model_type
         self.model = load_model(model_path) #

-    def get_sliced_feature(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25):
+    def get_sliced_feature(self,
+                           feature_array,
+                           vid_idx,
+                           audio_feat_length=[2,2],
+                           fps=25):
         """
         Get sliced features based on a given index
         :param feature_array:
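The truncated hunk above only shows the reflowed signature, but the parameters suggest what `get_sliced_feature` computes: the feature rows surrounding a given video frame. A sketch of that index arithmetic follows, under the assumption (not visible in this diff) that the feature array holds about 50 rows per second of audio, i.e. two rows per frame at `fps=25`; the clamping and window constants are guesses for illustration.

    import numpy as np

    def get_sliced_feature(feature_array, vid_idx, audio_feat_length=(2, 2), fps=25):
        # Map the video frame index to a feature-row index (assumed 50 rows/sec),
        # then take audio_feat_length[0] frames of context on the left and
        # audio_feat_length[1] on the right, two rows per frame, clamped in range.
        length = len(feature_array)
        center_idx = int(vid_idx * 50 / fps)
        left_idx = center_idx - audio_feat_length[0] * 2
        right_idx = center_idx + (audio_feat_length[1] + 1) * 2
        idxs = [min(max(i, 0), length - 1) for i in range(left_idx, right_idx)]
        return np.stack([feature_array[i] for i in idxs]), idxs

    features = np.zeros((250, 384), dtype=np.float32)   # ~5 s of 384-dim features
    sliced, idxs = get_sliced_feature(features, vid_idx=30)
    print(sliced.shape, idxs[0], idxs[-1])              # (10, 384) 56 65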