mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-04 09:29:20 +08:00
fix some cuda related issue when run on M-Series Mac
This commit is contained in:
@@ -36,11 +36,11 @@ class UNet():
|
|||||||
unet_config = json.load(f)
|
unet_config = json.load(f)
|
||||||
self.model = UNet2DConditionModel(**unet_config)
|
self.model = UNet2DConditionModel(**unet_config)
|
||||||
self.pe = PositionalEncoding(d_model=384)
|
self.pe = PositionalEncoding(d_model=384)
|
||||||
self.weights = torch.load(model_path)
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
|
||||||
self.model.load_state_dict(self.weights)
|
self.model.load_state_dict(self.weights)
|
||||||
if use_float16:
|
if use_float16:
|
||||||
self.model = self.model.half()
|
self.model = self.model.half()
|
||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -16,8 +16,11 @@ class FaceParsing():
|
|||||||
resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
|
resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
|
||||||
model_pth='./models/face-parse-bisent/79999_iter.pth'):
|
model_pth='./models/face-parse-bisent/79999_iter.pth'):
|
||||||
net = BiSeNet(resnet_path)
|
net = BiSeNet(resnet_path)
|
||||||
net.cuda()
|
if torch.cuda.is_available():
|
||||||
net.load_state_dict(torch.load(model_pth))
|
net.cuda()
|
||||||
|
net.load_state_dict(torch.load(model_pth))
|
||||||
|
else:
|
||||||
|
net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
|
||||||
net.eval()
|
net.eval()
|
||||||
return net
|
return net
|
||||||
|
|
||||||
@@ -35,7 +38,10 @@ class FaceParsing():
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
image = image.resize(size, Image.BILINEAR)
|
image = image.resize(size, Image.BILINEAR)
|
||||||
img = self.preprocess(image)
|
img = self.preprocess(image)
|
||||||
img = torch.unsqueeze(img, 0).cuda()
|
if torch.cuda.is_available():
|
||||||
|
img = torch.unsqueeze(img, 0).cuda()
|
||||||
|
else:
|
||||||
|
img = torch.unsqueeze(img, 0)
|
||||||
out = self.net(img)[0]
|
out = self.net(img)[0]
|
||||||
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
||||||
parsing[np.where(parsing>13)] = 0
|
parsing[np.where(parsing>13)] = 0
|
||||||
|
|||||||
Reference in New Issue
Block a user