This commit is contained in:
aidenyzhang
2025-03-28 16:03:02 +08:00
parent 058f7ddc7f
commit db204311a5
46 changed files with 729 additions and 204 deletions

View File

@@ -31,12 +31,16 @@ class UNet():
unet_config,
model_path,
use_float16=False,
device=None
):
with open(unet_config, 'r') as f:
unet_config = json.load(f)
self.model = UNet2DConditionModel(**unet_config)
self.pe = PositionalEncoding(d_model=384)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device != None:
self.device = device
else:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
self.model.load_state_dict(weights)
if use_float16: