Skip to content

Commit ddeb0c1

Browse files
tohtana and loadams authored
Fix patch for parameter partitioning in zero.Init() (#6388)
This PR fixes an issue addressed in #5921. With this change, the patch for parameter partitioning is applied only to classes that define their own `__init__`, which avoids applying the patch multiple times. A class that does not define `__init__` now uses its superclass's one, so this PR also applies the patch to the root class, `torch.nn.modules.module.Module`. Thanks @VeryLazyBoy for the report and initial solution. --------- Co-authored-by: Logan Adams <[email protected]>
1 parent 9d17116 commit ddeb0c1

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

deepspeed/runtime/zero/partition_parameters.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def new_tensor(cls, *args, **kwargs) -> Tensor:
262262

263263

264264
# https://stackoverflow.com/a/63851681/9201239
265-
def get_all_subclasses(cls):
265+
def get_all_subclasses(cls, include_root=True):
266266
subclass_list = []
267267

268268
def recurse(cl):
@@ -272,7 +272,10 @@ def recurse(cl):
272272

273273
recurse(cls)
274274

275-
return set(subclass_list)
275+
ret = set(subclass_list)
276+
if include_root:
277+
ret.add(cls)
278+
return ret
276279

277280

278281
@instrument_w_nvtx
@@ -465,11 +468,13 @@ def wrapper(*args, **kwargs):
465468
return wrapper
466469

467470
def _enable_class_apply(cls):
468-
cls._old_apply_of_skip_init_hook = cls._apply
469-
cls._apply = partition_after_empty_init(cls._apply)
471+
if '_apply' in cls.__dict__:
472+
cls._old_apply_of_skip_init_hook = cls._apply
473+
cls._apply = partition_after_empty_init(cls._apply)
470474

471475
def _disable_class_apply(cls):
472-
cls._apply = cls._old_apply_of_skip_init_hook
476+
if hasattr(cls, '_old_apply_of_skip_init_hook'):
477+
cls._apply = cls._old_apply_of_skip_init_hook
473478

474479
# add hooks for to_empty: apply_(empty_like)
475480
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
@@ -522,12 +527,14 @@ def wrapper(module, *args, **kwargs):
522527
return wrapper
523528

524529
def _enable_class(cls):
525-
cls._old_init = cls.__init__
526-
cls.__init__ = partition_after(cls.__init__)
530+
if '__init__' in cls.__dict__:
531+
cls._old_init = cls.__init__
532+
cls.__init__ = partition_after(cls.__init__)
527533

528534
def _init_subclass(cls, **kwargs):
529-
cls._old_init = cls.__init__
530-
cls.__init__ = partition_after(cls.__init__)
535+
if '__init__' in cls.__dict__:
536+
cls._old_init = cls.__init__
537+
cls.__init__ = partition_after(cls.__init__)
531538

532539
# Replace .__init__() for all existing subclasses of torch.nn.Module recursively
533540
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
@@ -567,7 +574,8 @@ def unpatch_init_and_builtins(self):
567574
if self.patched:
568575

569576
def _disable_class(cls):
570-
cls.__init__ = cls._old_init
577+
if hasattr(cls, '_old_init'):
578+
cls.__init__ = cls._old_init
571579

572580
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
573581
_disable_class(subclass)

0 commit comments

Comments
 (0)