fix: use torch.no_grad() in inference to prevent excessive memory usage (~30GB) (#349)

This commit is contained in:
GaoLeiA
2025-07-02 16:38:56 +08:00
committed by GitHub
parent 8ca7d1884c
commit 26ca7c2c03

View File

@@ -235,6 +235,7 @@ class Avatar:
cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png", combine_frame)
self.idx = self.idx + 1
@torch.no_grad()
def inference(self, audio_path, out_vid_name, fps, skip_save_images):
os.makedirs(self.avatar_path + '/tmp', exist_ok=True)
print("start inference")