mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-04 09:29:20 +08:00
Merge pull request #5 from hotea/main
fix some cuda related issue when run on M-Series Mac with cpu
This commit is contained in:
@@ -36,11 +36,11 @@ class UNet():
|
||||
unet_config = json.load(f)
|
||||
self.model = UNet2DConditionModel(**unet_config)
|
||||
self.pe = PositionalEncoding(d_model=384)
|
||||
self.weights = torch.load(model_path)
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
|
||||
self.model.load_state_dict(self.weights)
|
||||
if use_float16:
|
||||
self.model = self.model.half()
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.model.to(self.device)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -16,8 +16,11 @@ class FaceParsing():
|
||||
resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
|
||||
model_pth='./models/face-parse-bisent/79999_iter.pth'):
|
||||
net = BiSeNet(resnet_path)
|
||||
net.cuda()
|
||||
net.load_state_dict(torch.load(model_pth))
|
||||
if torch.cuda.is_available():
|
||||
net.cuda()
|
||||
net.load_state_dict(torch.load(model_pth))
|
||||
else:
|
||||
net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
|
||||
net.eval()
|
||||
return net
|
||||
|
||||
@@ -35,7 +38,10 @@ class FaceParsing():
|
||||
with torch.no_grad():
|
||||
image = image.resize(size, Image.BILINEAR)
|
||||
img = self.preprocess(image)
|
||||
img = torch.unsqueeze(img, 0).cuda()
|
||||
if torch.cuda.is_available():
|
||||
img = torch.unsqueeze(img, 0).cuda()
|
||||
else:
|
||||
img = torch.unsqueeze(img, 0)
|
||||
out = self.net(img)[0]
|
||||
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
||||
parsing[np.where(parsing>13)] = 0
|
||||
|
||||
Reference in New Issue
Block a user