File tree Expand file tree Collapse file tree 2 files changed +17
-18
lines changed Expand file tree Collapse file tree 2 files changed +17
-18
lines changed Original file line number Diff line number Diff line change @@ -600,24 +600,6 @@ def test_model_from_pretrained(self):
600
600
model = BertModel .from_pretrained (model_name )
601
601
self .assertIsNotNone (model )
602
602
603
- @slow
604
- def test_save_and_load_low_cpu_mem_usage (self ):
605
- with tempfile .TemporaryDirectory () as tmpdirname :
606
- for model_class in self .all_model_classes :
607
- config , inputs_dict = self .model_tester .prepare_config_and_inputs_for_common ()
608
- model_to_save = model_class (config )
609
-
610
- model_to_save .save_pretrained (tmpdirname )
611
-
612
- model = model_class .from_pretrained (
613
- tmpdirname ,
614
- low_cpu_mem_usage = True ,
615
- )
616
-
617
- # The low_cpu_mem_usage=True causes the model params to be initialized with device=meta. If there are
618
- # any unloaded or untied parameters, then trying to move it to device=torch_device will throw an error.
619
- model .to (torch_device )
620
-
621
603
@slow
622
604
@require_torch_accelerator
623
605
def test_torchscript_device_change (self ):
Original file line number Diff line number Diff line change @@ -435,6 +435,23 @@ class CopyClass(model_class):
435
435
max_diff = (model_slow_init .state_dict ()[key ] - model_fast_init .state_dict ()[key ]).sum ().item ()
436
436
self .assertLessEqual (max_diff , 1e-3 , msg = f"{ key } not identical" )
437
437
438
+ def test_save_and_load_low_cpu_mem_usage (self ):
439
+ with tempfile .TemporaryDirectory () as tmpdirname :
440
+ for model_class in self .all_model_classes :
441
+ config , inputs_dict = self .model_tester .prepare_config_and_inputs_for_common ()
442
+ model_to_save = model_class (config )
443
+
444
+ model_to_save .save_pretrained (tmpdirname )
445
+
446
+ model = model_class .from_pretrained (
447
+ tmpdirname ,
448
+ low_cpu_mem_usage = True ,
449
+ )
450
+
451
+ # The low_cpu_mem_usage=True causes the model params to be initialized with device=meta. If there are
452
+ # any unloaded or untied parameters, then trying to move it to device=torch_device will throw an error.
453
+ model .to (torch_device )
454
+
438
455
def test_fast_init_context_manager (self ):
439
456
# 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
440
457
class MyClass (PreTrainedModel ):
You can’t perform that action at this time.
0 commit comments