diff --git a/musetalk/models/unet.py b/musetalk/models/unet.py index 02a9408..8968657 100755 --- a/musetalk/models/unet.py +++ b/musetalk/models/unet.py @@ -36,11 +36,11 @@ class UNet(): unet_config = json.load(f) self.model = UNet2DConditionModel(**unet_config) self.pe = PositionalEncoding(d_model=384) - self.weights = torch.load(model_path) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device) self.model.load_state_dict(self.weights) if use_float16: self.model = self.model.half() - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model.to(self.device) if __name__ == "__main__": diff --git a/musetalk/utils/face_parsing/__init__.py b/musetalk/utils/face_parsing/__init__.py index 5bfddba..fc963a3 100755 --- a/musetalk/utils/face_parsing/__init__.py +++ b/musetalk/utils/face_parsing/__init__.py @@ -16,8 +16,11 @@ class FaceParsing(): resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth', model_pth='./models/face-parse-bisent/79999_iter.pth'): net = BiSeNet(resnet_path) - net.cuda() - net.load_state_dict(torch.load(model_pth)) + if torch.cuda.is_available(): + net.cuda() + net.load_state_dict(torch.load(model_pth)) + else: + net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu'))) net.eval() return net @@ -35,7 +38,10 @@ class FaceParsing(): with torch.no_grad(): image = image.resize(size, Image.BILINEAR) img = self.preprocess(image) - img = torch.unsqueeze(img, 0).cuda() + if torch.cuda.is_available(): + img = torch.unsqueeze(img, 0).cuda() + else: + img = torch.unsqueeze(img, 0) out = self.net(img)[0] parsing = out.squeeze(0).cpu().numpy().argmax(0) parsing[np.where(parsing>13)] = 0