Skip to content

Commit e884298

Browse files
committed
merge with utilities
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 5f483f8 commit e884298

File tree

2 files changed

+97
-9
lines changed

2 files changed

+97
-9
lines changed

src/compressed_tensors/utils/offload.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from accelerate.hooks import (
3939
AlignDevicesHook,
4040
add_hook_to_module,
41+
attach_align_device_hook,
4142
named_module_tensors,
4243
remove_hook_from_module,
4344
)
@@ -58,6 +59,7 @@
5859
set_module_tensor_to_device = None
5960
named_module_tensors = None
6061
dispatch_model = None
62+
attach_align_device_hook = None
6163

6264

6365
__all__ = [
@@ -458,21 +460,42 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
458460

459461

460462
@check_accelerate(fallback="error")
461-
def force_cpu_offload(module: torch.nn.Module, execution_device: torch.device):
463+
def force_cpu_offload(
464+
module: torch.nn.Module, execution_device: torch.device
465+
) -> torch.nn.Module:
466+
"""
467+
Force cpu offloading a module, primarily used for testing
468+
469+
:param module: module containing parameters to offload
470+
:param execution_device: device on which the module and its submodules execute
471+
:return: module with hooks to perform cpu offloading
472+
"""
473+
# edge case: there is a bug in `dispatch_model` which causes
474+
# the function to only work if the model contains submodules
475+
if next(module.children(), None) is None:
476+
attach_align_device_hook(
477+
module,
478+
execution_device=execution_device,
479+
offload=True,
480+
weights_map=module.state_dict(),
481+
tied_params_map={},
482+
)
483+
return module
484+
462485
device_map = {}
463486

464-
def dfs(name: List[str], module: torch.nn.Module):
487+
def collect_device_map(name: List[str], module: torch.nn.Module):
465488
if next(module.parameters(recurse=False), None) is not None:
466489
device_map[".".join(name)] = "cpu"
467490
return
468491

469492
else:
470493
for submodule_name, submodule in module.named_children():
471494
name.append(submodule_name)
472-
dfs(name, submodule)
495+
collect_device_map(name, submodule)
473496
name.pop()
474497

475-
dfs([], module)
498+
collect_device_map([], module)
476499

477500
return dispatch_model(
478501
module, device_map, main_device=execution_device, force_hooks=True

tests/test_utils/test_offload.py

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
align_modules,
1919
delete_offload_parameter,
2020
disable_hf_hook,
21+
force_cpu_offload,
2122
get_execution_device,
2223
has_offloaded_params,
24+
register_offload_module,
2325
register_offload_parameter,
2426
update_offload_parameter,
2527
)
@@ -37,9 +39,17 @@ def forward(self, x):
3739
return x * self.a + self.b
3840

3941

42+
class ExampleModel(torch.nn.Module):
    """
    Minimal model used by the offloading tests: a parent module whose only
    child is a single ``Linear(1, 2)`` layer stored as ``self.linear``
    """

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=1, out_features=2)

    def forward(self, x):
        # delegate directly to the only submodule
        output = self.linear(x)
        return output
49+
50+
4051
@requires_accelerate()
4152
def test_has_offloaded_params():
42-
from accelerate.big_modeling import cpu_offload_with_hook
4353
from accelerate.hooks import attach_align_device_hook, remove_hook_from_module
4454

4555
module = ExampleModule()
@@ -48,10 +58,6 @@ def test_has_offloaded_params():
4858
attach_align_device_hook(module, offload=False)
4959
assert not has_offloaded_params(module)
5060

51-
remove_hook_from_module(module)
52-
module, _ = cpu_offload_with_hook(module)
53-
assert not has_offloaded_params(module)
54-
5561
remove_hook_from_module(module)
5662
attach_align_device_hook(module, offload=True, weights_map=module.state_dict())
5763
assert has_offloaded_params(module)
@@ -334,3 +340,62 @@ def test_offload_to_weights_map():
334340
weights_map = PrefixedDataset(OffloadedWeightsLoader({name: old_value}), prefix)
335341
offload_to_weights_map(weights_map, name, new_value)
336342
assert weights_map[name] == new_value
343+
344+
345+
@requires_gpu
@requires_accelerate()
def test_register_offload_module():
    """
    Check that `register_offload_module` attaches a new child module to both
    a model and one of its submodules, first without offloading and then
    after the model has been passed through `force_cpu_offload`, and that the
    resulting modules can still be called
    """
    execution_device = torch.device("cuda")

    # no offloading
    plain_model = ExampleModel()
    plain_child = torch.nn.Linear(2, 3)
    for parent in (plain_model, plain_model.linear):
        register_offload_module(parent, "child", plain_child)
    assert plain_child in plain_model.children()
    assert plain_child in plain_model.linear.children()

    # with offloading
    offloaded_model = ExampleModel()
    offloaded_child = torch.nn.Linear(2, 3)
    force_cpu_offload(offloaded_model, execution_device)
    for parent in (offloaded_model, offloaded_model.linear):
        register_offload_module(parent, "child", offloaded_child)
    assert offloaded_child in offloaded_model.children()
    assert offloaded_child in offloaded_model.linear.children()

    # can run modules
    model_input = torch.empty(1)
    offloaded_model(model_input)
    child_input = torch.empty(2, device=execution_device)
    offloaded_child(child_input)
370+
371+
372+
@requires_gpu
@requires_accelerate()
def test_force_cpu_offload():
    """
    Check `force_cpu_offload` on a childless module and on a model with a
    submodule: the module holding parameters must end up with an offloading
    hook, meta-device weights, and a populated weights map, and must still
    be callable with inputs on the execution device
    """
    execution_device = torch.device("cuda")
    meta_device = torch.device("meta")

    def assert_is_offloaded(offloaded: torch.nn.Module):
        # shared expectations for any module that directly owns parameters
        hook = offloaded._hf_hook
        assert has_offloaded_params(offloaded)
        assert hook.offload
        assert offloaded.weight.device == meta_device
        assert "weight" in hook.weights_map
        assert hook.tied_params_map is not None

    # single module
    module = force_cpu_offload(torch.nn.Linear(1, 2), execution_device)
    assert_is_offloaded(module)

    # can run
    module(torch.empty(1, device=execution_device))

    # model
    model = force_cpu_offload(ExampleModel(), execution_device)
    assert not has_offloaded_params(model)

    assert_is_offloaded(model.linear)

    # can run
    model(torch.empty(1, device=execution_device))

0 commit comments

Comments
 (0)