<enhance>: support using float16 in inference to speed up

This commit is contained in:
czk32611
2024-04-27 14:26:50 +08:00
parent 2c52de01b4
commit 865a68c60e
6 changed files with 103 additions and 51 deletions

View File

@@ -16,12 +16,18 @@ from musetalk.utils.utils import load_all_model
import shutil
# load model weights
audio_processor,vae,unet,pe = load_all_model()
audio_processor, vae, unet, pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
@torch.no_grad()
def main(args):
global pe
if args.use_float16 is True:
pe = pe.half()
vae.vae = vae.vae.half()
unet.model = unet.model.half()
inference_config = OmegaConf.load(args.inference_config)
print(inference_config)
for task_id in inference_config:
@@ -96,10 +102,11 @@ def main(args):
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
res_frame_list = []
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384
audio_feature_batch = torch.from_numpy(whisper_batch)
audio_feature_batch = audio_feature_batch.to(device=unet.device,
dtype=unet.model.dtype) # torch, B, 5*N,384
audio_feature_batch = pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
recon = vae.decode_latents(pred_latents)
@@ -145,7 +152,10 @@ if __name__ == "__main__":
parser.add_argument("--use_saved_coord",
action="store_true",
help='use saved coordinate to save time')
parser.add_argument("--use_float16",
action="store_true",
help="Whether use float16 to speed up inference",
)
args = parser.parse_args()
main(args)