Skip to content

Commit de380fb

Browse files
committed
Fix model loading error
1 parent 6526409 commit de380fb

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

autoencoder/pixelvae.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def load_pixelvae_model(weights_path, device, key):
282282
binning = HSVCube(32, 8, 16)
283283

284284
model = Decoder(binning.bins_per_channel)
285-
state_dict = torch.load(decryptedStream, weights_only=True)
285+
state_dict = torch.load(decryptedStream, map_location=device)
286286
model.load_state_dict(state_dict['state_dict'])
287287
model.eval()
288288
model = model.to(device)

0 commit comments

Comments
 (0)