18
18
align_modules ,
19
19
delete_offload_parameter ,
20
20
disable_hf_hook ,
21
+ force_cpu_offload ,
21
22
get_execution_device ,
22
23
has_offloaded_params ,
24
+ register_offload_module ,
23
25
register_offload_parameter ,
24
26
update_offload_parameter ,
25
27
)
@@ -37,9 +39,17 @@ def forward(self, x):
37
39
return x * self .a + self .b
38
40
39
41
42
+ class ExampleModel (torch .nn .Module ):
43
+ def __init__ (self ):
44
+ super ().__init__ ()
45
+ self .linear = torch .nn .Linear (1 , 2 )
46
+
47
+ def forward (self , x ):
48
+ return self .linear (x )
49
+
50
+
40
51
@requires_accelerate ()
41
52
def test_has_offloaded_params ():
42
- from accelerate .big_modeling import cpu_offload_with_hook
43
53
from accelerate .hooks import attach_align_device_hook , remove_hook_from_module
44
54
45
55
module = ExampleModule ()
@@ -48,10 +58,6 @@ def test_has_offloaded_params():
48
58
attach_align_device_hook (module , offload = False )
49
59
assert not has_offloaded_params (module )
50
60
51
- remove_hook_from_module (module )
52
- module , _ = cpu_offload_with_hook (module )
53
- assert not has_offloaded_params (module )
54
-
55
61
remove_hook_from_module (module )
56
62
attach_align_device_hook (module , offload = True , weights_map = module .state_dict ())
57
63
assert has_offloaded_params (module )
@@ -334,3 +340,62 @@ def test_offload_to_weights_map():
334
340
weights_map = PrefixedDataset (OffloadedWeightsLoader ({name : old_value }), prefix )
335
341
offload_to_weights_map (weights_map , name , new_value )
336
342
assert weights_map [name ] == new_value
343
+
344
+
345
+ @requires_gpu
346
+ @requires_accelerate ()
347
+ def test_register_offload_module ():
348
+ execution_device = torch .device ("cuda" )
349
+
350
+ # no offloading
351
+ model = ExampleModel ()
352
+ child = torch .nn .Linear (2 , 3 )
353
+ register_offload_module (model , "child" , child )
354
+ register_offload_module (model .linear , "child" , child )
355
+ assert child in model .children ()
356
+ assert child in model .linear .children ()
357
+
358
+ # with offloading
359
+ model = ExampleModel ()
360
+ child = torch .nn .Linear (2 , 3 )
361
+ force_cpu_offload (model , execution_device )
362
+ register_offload_module (model , "child" , child )
363
+ register_offload_module (model .linear , "child" , child )
364
+ assert child in model .children ()
365
+ assert child in model .linear .children ()
366
+
367
+ # can run modules
368
+ model (torch .empty (1 ))
369
+ child (torch .empty (2 , device = execution_device ))
370
+
371
+
372
+ @requires_gpu
373
+ @requires_accelerate ()
374
+ def test_force_cpu_offload ():
375
+ execution_device = torch .device ("cuda" )
376
+
377
+ # single module
378
+ module = torch .nn .Linear (1 , 2 )
379
+ module = force_cpu_offload (module , execution_device )
380
+ assert has_offloaded_params (module )
381
+ assert module ._hf_hook .offload
382
+ assert module .weight .device == torch .device ("meta" )
383
+ assert "weight" in module ._hf_hook .weights_map
384
+ assert module ._hf_hook .tied_params_map is not None
385
+
386
+ # can run
387
+ module (torch .empty (1 , device = execution_device ))
388
+
389
+ # model
390
+ model = ExampleModel ()
391
+ model = force_cpu_offload (model , execution_device )
392
+ assert not has_offloaded_params (model )
393
+
394
+ assert has_offloaded_params (model .linear )
395
+ assert model .linear ._hf_hook .offload
396
+ assert model .linear .weight .device == torch .device ("meta" )
397
+ assert "weight" in model .linear ._hf_hook .weights_map
398
+ assert model .linear ._hf_hook .tied_params_map is not None
399
+
400
+ # can run
401
+ model (torch .empty (1 , device = execution_device ))
0 commit comments