mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-04 17:39:20 +08:00
fix some cuda related issue when run on M-Series Mac
This commit is contained in:
@@ -16,8 +16,11 @@ class FaceParsing():
|
||||
resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
|
||||
model_pth='./models/face-parse-bisent/79999_iter.pth'):
|
||||
net = BiSeNet(resnet_path)
|
||||
net.cuda()
|
||||
net.load_state_dict(torch.load(model_pth))
|
||||
if torch.cuda.is_available():
|
||||
net.cuda()
|
||||
net.load_state_dict(torch.load(model_pth))
|
||||
else:
|
||||
net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
|
||||
net.eval()
|
||||
return net
|
||||
|
||||
@@ -35,7 +38,10 @@ class FaceParsing():
|
||||
with torch.no_grad():
|
||||
image = image.resize(size, Image.BILINEAR)
|
||||
img = self.preprocess(image)
|
||||
img = torch.unsqueeze(img, 0).cuda()
|
||||
if torch.cuda.is_available():
|
||||
img = torch.unsqueeze(img, 0).cuda()
|
||||
else:
|
||||
img = torch.unsqueeze(img, 0)
|
||||
out = self.net(img)[0]
|
||||
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
||||
parsing[np.where(parsing>13)] = 0
|
||||
|
||||
Reference in New Issue
Block a user