mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-04 17:39:20 +08:00
<enhance>: support using float16 in inference to speed up
This commit is contained in:
@@ -16,12 +16,18 @@ from musetalk.utils.utils import load_all_model
|
||||
import shutil
|
||||
|
||||
# load model weights
|
||||
audio_processor,vae,unet,pe = load_all_model()
|
||||
audio_processor, vae, unet, pe = load_all_model()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
@torch.no_grad()
|
||||
def main(args):
|
||||
global pe
|
||||
if args.use_float16 is True:
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
|
||||
inference_config = OmegaConf.load(args.inference_config)
|
||||
print(inference_config)
|
||||
for task_id in inference_config:
|
||||
@@ -96,10 +102,11 @@ def main(args):
|
||||
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
|
||||
res_frame_list = []
|
||||
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
|
||||
|
||||
tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
|
||||
audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384
|
||||
audio_feature_batch = torch.from_numpy(whisper_batch)
|
||||
audio_feature_batch = audio_feature_batch.to(device=unet.device,
|
||||
dtype=unet.model.dtype) # torch, B, 5*N,384
|
||||
audio_feature_batch = pe(audio_feature_batch)
|
||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
||||
|
||||
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
@@ -145,7 +152,10 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--use_saved_coord",
|
||||
action="store_true",
|
||||
help='use saved coordinate to save time')
|
||||
|
||||
parser.add_argument("--use_float16",
|
||||
action="store_true",
|
||||
help="Whether use float16 to speed up inference",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user