mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-05 18:09:19 +08:00
v1.5
This commit is contained in:
@@ -31,12 +31,16 @@ class UNet():
|
||||
unet_config,
|
||||
model_path,
|
||||
use_float16=False,
|
||||
device=None
|
||||
):
|
||||
with open(unet_config, 'r') as f:
|
||||
unet_config = json.load(f)
|
||||
self.model = UNet2DConditionModel(**unet_config)
|
||||
self.pe = PositionalEncoding(d_model=384)
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if device != None:
|
||||
self.device = device
|
||||
else:
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
|
||||
self.model.load_state_dict(weights)
|
||||
if use_float16:
|
||||
|
||||
Reference in New Issue
Block a user