fix some cuda related issue when run on M-Series Mac

This commit is contained in:
=
2024-04-05 22:03:28 +08:00
parent bc1379abad
commit 9a5212c8dd
2 changed files with 11 additions and 5 deletions

View File

@@ -36,11 +36,11 @@ class UNet():
unet_config = json.load(f)
self.model = UNet2DConditionModel(**unet_config)
self.pe = PositionalEncoding(d_model=384)
self.weights = torch.load(model_path)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
self.model.load_state_dict(self.weights)
if use_float16:
self.model = self.model.half()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
if __name__ == "__main__":

View File

@@ -16,8 +16,11 @@ class FaceParsing():
resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
model_pth='./models/face-parse-bisent/79999_iter.pth'):
net = BiSeNet(resnet_path)
net.cuda()
net.load_state_dict(torch.load(model_pth))
if torch.cuda.is_available():
net.cuda()
net.load_state_dict(torch.load(model_pth))
else:
net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
net.eval()
return net
@@ -35,7 +38,10 @@ class FaceParsing():
with torch.no_grad():
image = image.resize(size, Image.BILINEAR)
img = self.preprocess(image)
img = torch.unsqueeze(img, 0).cuda()
if torch.cuda.is_available():
img = torch.unsqueeze(img, 0).cuda()
else:
img = torch.unsqueeze(img, 0)
out = self.net(img)[0]
parsing = out.squeeze(0).cpu().numpy().argmax(0)
parsing[np.where(parsing>13)] = 0