<enhance>: support using float16 in inference to speed up

This commit is contained in:
czk32611
2024-04-27 14:26:50 +08:00
parent 2c52de01b4
commit 865a68c60e
6 changed files with 103 additions and 51 deletions

View File

@@ -37,11 +37,11 @@ class UNet():
self.model = UNet2DConditionModel(**unet_config)
self.pe = PositionalEncoding(d_model=384)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
self.model.load_state_dict(self.weights)
weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
self.model.load_state_dict(weights)
if use_float16:
self.model = self.model.half()
self.model.to(self.device)
if __name__ == "__main__":
unet = UNet()
unet = UNet()