     PushInProgress,
     PushToHubMixin,
     can_return_loss,
+    check_torch_load_is_safe,
     find_labels,
     is_accelerate_available,
     is_apex_available,
@@ -2831,6 +2832,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
                     logger.warning(
                         "Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
                     )
+                check_torch_load_is_safe()
                 state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
                 # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                 state_dict["_smp_is_partial"] = False
@@ -2850,6 +2852,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
             if self.args.save_safetensors and os.path.isfile(safe_weights_file):
                 state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
             else:
+                check_torch_load_is_safe()
                 state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
 
             # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
@@ -2944,6 +2947,7 @@ def _load_best_model(self):
                 if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                     state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                 else:
+                    check_torch_load_is_safe()
                     state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
 
                 state_dict["_smp_is_partial"] = False
@@ -2999,6 +3003,7 @@ def _load_best_model(self):
                 if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                     state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                 else:
+                    check_torch_load_is_safe()
                     state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
 
                 # If the model is on the GPU, it still works!
@@ -3354,6 +3359,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
             if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
                 with warnings.catch_warnings(record=True) as caught_warnings:
+                    check_torch_load_is_safe()
                     self.lr_scheduler.load_state_dict(
                         torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
                     )
@@ -3386,6 +3392,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             if is_torch_xla_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
                 if self.is_fsdp_xla_v1_enabled:
+                    check_torch_load_is_safe()
                     optimizer_state = torch.load(
                         os.path.join(
                             checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
@@ -3396,10 +3403,12 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                     # We only need `optimizer` when resuming from checkpoint
                     optimizer_state = optimizer_state["optimizer"]
                 else:
+                    check_torch_load_is_safe()
                     optimizer_state = torch.load(
                         os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu", weights_only=True
                     )
                 with warnings.catch_warnings(record=True) as caught_warnings:
+                    check_torch_load_is_safe()
                     lr_scheduler_state = torch.load(
                         os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu", weights_only=True
                     )
@@ -3443,12 +3452,14 @@ def opt_load_hook(mod, opt):
                         **_get_fsdp_ckpt_kwargs(),
                     )
                 else:
+                    check_torch_load_is_safe()
                     self.optimizer.load_state_dict(
                         torch.load(
                             os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location, weights_only=True
                         )
                     )
                 with warnings.catch_warnings(record=True) as caught_warnings:
+                    check_torch_load_is_safe()
                     self.lr_scheduler.load_state_dict(
                         torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
                     )
@@ -3486,6 +3497,7 @@ def _load_scaler(self, checkpoint):
         # Load in scaler states
         if is_torch_xla_available():
             with warnings.catch_warnings(record=True) as caught_warnings:
+                check_torch_load_is_safe()
                 scaler_state = torch.load(
                     os.path.join(checkpoint, SCALER_NAME), map_location="cpu", weights_only=True
                 )
@@ -3494,6 +3506,7 @@ def _load_scaler(self, checkpoint):
             self.accelerator.scaler.load_state_dict(scaler_state)
         else:
             with warnings.catch_warnings(record=True) as caught_warnings:
+                check_torch_load_is_safe()
                 self.accelerator.scaler.load_state_dict(
                     torch.load(os.path.join(checkpoint, SCALER_NAME), weights_only=True)
                 )
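
Every hunk above applies the same pattern: call check_torch_load_is_safe() immediately before a torch.load(..., weights_only=True) on a checkpoint file, while the safetensors branches are left unguarded because they never go through pickle. Below is a minimal sketch of that pattern outside the Trainer; the helper name load_checkpoint_state and the default filename are illustrative only, and the import path assumes a transformers version that already ships this change.

import os

import torch

from transformers.utils import check_torch_load_is_safe  # added by this change


def load_checkpoint_state(checkpoint_dir: str, filename: str = "pytorch_model.bin") -> dict:
    """Load a raw state dict with the same safeguards the Trainer uses.

    check_torch_load_is_safe() is expected to raise when the installed torch version
    cannot deserialize pickled checkpoints safely, and weights_only=True restricts
    unpickling to tensors and plain containers.
    """
    weights_file = os.path.join(checkpoint_dir, filename)
    check_torch_load_is_safe()
    return torch.load(weights_file, map_location="cpu", weights_only=True)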