Skip to content

Commit e40f605

Browse files
committed
Moving test_save_and_load_low_cpu_mem_usage to ModelTesterMixin
1 parent e1c153b commit e40f605

File tree

2 files changed

+17
-18
lines changed

2 files changed

+17
-18
lines changed

tests/models/bert/test_modeling_bert.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -600,24 +600,6 @@ def test_model_from_pretrained(self):
600600
model = BertModel.from_pretrained(model_name)
601601
self.assertIsNotNone(model)
602602

603-
@slow
604-
def test_save_and_load_low_cpu_mem_usage(self):
605-
with tempfile.TemporaryDirectory() as tmpdirname:
606-
for model_class in self.all_model_classes:
607-
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
608-
model_to_save = model_class(config)
609-
610-
model_to_save.save_pretrained(tmpdirname)
611-
612-
model = model_class.from_pretrained(
613-
tmpdirname,
614-
low_cpu_mem_usage=True,
615-
)
616-
617-
# The low_cpu_mem_usage=True causes the model params to be initialized with device=meta. If there are
618-
# any unloaded or untied parameters, then trying to move it to device=torch_device will throw an error.
619-
model.to(torch_device)
620-
621603
@slow
622604
@require_torch_accelerator
623605
def test_torchscript_device_change(self):

tests/test_modeling_common.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,23 @@ class CopyClass(model_class):
435435
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
436436
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
437437

438+
def test_save_and_load_low_cpu_mem_usage(self):
439+
with tempfile.TemporaryDirectory() as tmpdirname:
440+
for model_class in self.all_model_classes:
441+
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
442+
model_to_save = model_class(config)
443+
444+
model_to_save.save_pretrained(tmpdirname)
445+
446+
model = model_class.from_pretrained(
447+
tmpdirname,
448+
low_cpu_mem_usage=True,
449+
)
450+
451+
# The low_cpu_mem_usage=True causes the model params to be initialized with device=meta. If there are
452+
# any unloaded or untied parameters, then trying to move it to device=torch_device will throw an error.
453+
model.to(torch_device)
454+
438455
def test_fast_init_context_manager(self):
439456
# 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
440457
class MyClass(PreTrainedModel):

0 commit comments

Comments (0)