Skip to content

Commit 83d221f

Browse files
a-r-r-o-wDN6
andauthored
PAB Refactor (#10667)
* update * update * update --------- Co-authored-by: DN6 <[email protected]>
1 parent 3f3e26a commit 83d221f

File tree

3 files changed

+48
-32
lines changed

3 files changed

+48
-32
lines changed

src/diffusers/hooks/hooks.py

+46-28
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class ModelHook:
3232
_is_stateful = False
3333

3434
def __init__(self):
35-
self.fn_ref: "FunctionReference" = None
35+
self.fn_ref: "HookFunctionReference" = None
3636

3737
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
3838
r"""
@@ -101,12 +101,27 @@ def reset_state(self, module: torch.nn.Module):
101101
return module
102102

103103

104-
class FunctionReference:
104+
class HookFunctionReference:
105105
def __init__(self) -> None:
106+
"""A container class that maintains mutable references to forward pass functions in a hook chain.
107+
108+
Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the
109+
entire forward pass structure.
110+
111+
Attributes:
112+
pre_forward: A callable that processes inputs before the main forward pass.
113+
post_forward: A callable that processes outputs after the main forward pass.
114+
forward: The current forward function in the hook chain.
115+
original_forward: The original forward function, stored when a hook provides a custom new_forward.
116+
117+
The class enables hook removal by allowing updates to the forward chain through reference modification rather
118+
than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to
119+
be updated, preserving the execution order of the remaining hooks.
120+
"""
106121
self.pre_forward = None
107122
self.post_forward = None
108-
self.old_forward = None
109-
self.overwritten_forward = None
123+
self.forward = None
124+
self.original_forward = None
110125

111126

112127
class HookRegistry:
@@ -125,24 +140,24 @@ def register_hook(self, hook: ModelHook, name: str) -> None:
125140

126141
self._module_ref = hook.initialize_hook(self._module_ref)
127142

128-
def create_new_forward(function_reference: FunctionReference):
143+
def create_new_forward(function_reference: HookFunctionReference):
129144
def new_forward(module, *args, **kwargs):
130145
args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
131-
output = function_reference.old_forward(*args, **kwargs)
146+
output = function_reference.forward(*args, **kwargs)
132147
return function_reference.post_forward(module, output)
133148

134149
return new_forward
135150

136151
forward = self._module_ref.forward
137152

138-
fn_ref = FunctionReference()
153+
fn_ref = HookFunctionReference()
139154
fn_ref.pre_forward = hook.pre_forward
140155
fn_ref.post_forward = hook.post_forward
141-
fn_ref.old_forward = forward
156+
fn_ref.forward = forward
142157

143158
if hasattr(hook, "new_forward"):
144-
fn_ref.overwritten_forward = forward
145-
fn_ref.old_forward = functools.update_wrapper(
159+
fn_ref.original_forward = forward
160+
fn_ref.forward = functools.update_wrapper(
146161
functools.partial(hook.new_forward, self._module_ref), hook.new_forward
147162
)
148163

@@ -160,25 +175,28 @@ def get_hook(self, name: str) -> Optional[ModelHook]:
160175
return self.hooks.get(name, None)
161176

162177
def remove_hook(self, name: str, recurse: bool = True) -> None:
163-
num_hooks = len(self._hook_order)
164-
if name in self.hooks.keys():
165-
hook = self.hooks[name]
166-
index = self._hook_order.index(name)
167-
fn_ref = self._fn_refs[index]
168-
169-
old_forward = fn_ref.old_forward
170-
if fn_ref.overwritten_forward is not None:
171-
old_forward = fn_ref.overwritten_forward
178+
if name not in self.hooks.keys():
179+
logger.warning(f"hook: {name} was not found in HookRegistry")
180+
return
172181

173-
if index == num_hooks - 1:
174-
self._module_ref.forward = old_forward
175-
else:
176-
self._fn_refs[index + 1].old_forward = old_forward
177-
178-
self._module_ref = hook.deinitalize_hook(self._module_ref)
179-
del self.hooks[name]
180-
self._hook_order.pop(index)
181-
self._fn_refs.pop(index)
182+
num_hooks = len(self._hook_order)
183+
hook = self.hooks[name]
184+
index = self._hook_order.index(name)
185+
fn_ref = self._fn_refs[index]
186+
187+
old_forward = fn_ref.forward
188+
if fn_ref.original_forward is not None:
189+
old_forward = fn_ref.original_forward
190+
191+
if index == num_hooks - 1:
192+
self._module_ref.forward = old_forward
193+
else:
194+
self._fn_refs[index + 1].forward = old_forward
195+
196+
self._module_ref = hook.deinitalize_hook(self._module_ref)
197+
del self.hooks[name]
198+
self._hook_order.pop(index)
199+
self._fn_refs.pop(index)
182200

183201
if recurse:
184202
for module_name, module in self._module_ref.named_modules():

src/diffusers/hooks/pyramid_attention_broadcast.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
162162
)
163163

164164
if should_compute_attention:
165-
output = self.fn_ref.overwritten_forward(*args, **kwargs)
165+
output = self.fn_ref.original_forward(*args, **kwargs)
166166
else:
167167
output = self.state.cache
168168

tests/hooks/test_hooks.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def new_forward(self, module, *args, **kwargs):
126126
logger.debug("SkipLayerHook new_forward")
127127
if self.skip_layer:
128128
return args[0]
129-
return self.fn_ref.overwritten_forward(*args, **kwargs)
129+
return self.fn_ref.original_forward(*args, **kwargs)
130130

131131
def post_forward(self, module, output):
132132
logger.debug("SkipLayerHook post_forward")
@@ -174,14 +174,12 @@ def test_hook_registry(self):
174174

175175
self.assertEqual(len(registry.hooks), 2)
176176
self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
177-
self.assertEqual(len(registry._fn_refs), 2)
178177
self.assertEqual(registry_repr, expected_repr)
179178

180179
registry.remove_hook("add_hook")
181180

182181
self.assertEqual(len(registry.hooks), 1)
183182
self.assertEqual(registry._hook_order, ["multiply_hook"])
184-
self.assertEqual(len(registry._fn_refs), 1)
185183

186184
def test_stateful_hook(self):
187185
registry = HookRegistry.check_if_exists_or_initialize(self.model)

0 commit comments

Comments
 (0)