@@ -32,7 +32,7 @@ class ModelHook:
32
32
_is_stateful = False
33
33
34
34
def __init__ (self ):
35
- self .fn_ref : "FunctionReference " = None
35
+ self .fn_ref : "HookFunctionReference " = None
36
36
37
37
def initialize_hook (self , module : torch .nn .Module ) -> torch .nn .Module :
38
38
r"""
@@ -101,12 +101,27 @@ def reset_state(self, module: torch.nn.Module):
101
101
return module
102
102
103
103
104
- class FunctionReference :
104
+ class HookFunctionReference :
105
105
def __init__ (self ) -> None :
106
+ """A container class that maintains mutable references to forward pass functions in a hook chain.
107
+
108
+ Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the
109
+ entire forward pass structure.
110
+
111
+ Attributes:
112
+ pre_forward: A callable that processes inputs before the main forward pass.
113
+ post_forward: A callable that processes outputs after the main forward pass.
114
+ forward: The current forward function in the hook chain.
115
+ original_forward: The original forward function, stored when a hook provides a custom new_forward.
116
+
117
+ The class enables hook removal by allowing updates to the forward chain through reference modification rather
118
+ than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to
119
+ be updated, preserving the execution order of the remaining hooks.
120
+ """
106
121
self .pre_forward = None
107
122
self .post_forward = None
108
- self .old_forward = None
109
- self .overwritten_forward = None
123
+ self .forward = None
124
+ self .original_forward = None
110
125
111
126
112
127
class HookRegistry :
@@ -125,24 +140,24 @@ def register_hook(self, hook: ModelHook, name: str) -> None:
125
140
126
141
self ._module_ref = hook .initialize_hook (self ._module_ref )
127
142
128
- def create_new_forward (function_reference : FunctionReference ):
143
+ def create_new_forward (function_reference : HookFunctionReference ):
129
144
def new_forward (module , * args , ** kwargs ):
130
145
args , kwargs = function_reference .pre_forward (module , * args , ** kwargs )
131
- output = function_reference .old_forward (* args , ** kwargs )
146
+ output = function_reference .forward (* args , ** kwargs )
132
147
return function_reference .post_forward (module , output )
133
148
134
149
return new_forward
135
150
136
151
forward = self ._module_ref .forward
137
152
138
- fn_ref = FunctionReference ()
153
+ fn_ref = HookFunctionReference ()
139
154
fn_ref .pre_forward = hook .pre_forward
140
155
fn_ref .post_forward = hook .post_forward
141
- fn_ref .old_forward = forward
156
+ fn_ref .forward = forward
142
157
143
158
if hasattr (hook , "new_forward" ):
144
- fn_ref .overwritten_forward = forward
145
- fn_ref .old_forward = functools .update_wrapper (
159
+ fn_ref .original_forward = forward
160
+ fn_ref .forward = functools .update_wrapper (
146
161
functools .partial (hook .new_forward , self ._module_ref ), hook .new_forward
147
162
)
148
163
@@ -160,25 +175,28 @@ def get_hook(self, name: str) -> Optional[ModelHook]:
160
175
return self .hooks .get (name , None )
161
176
162
177
def remove_hook (self , name : str , recurse : bool = True ) -> None :
163
- num_hooks = len (self ._hook_order )
164
- if name in self .hooks .keys ():
165
- hook = self .hooks [name ]
166
- index = self ._hook_order .index (name )
167
- fn_ref = self ._fn_refs [index ]
168
-
169
- old_forward = fn_ref .old_forward
170
- if fn_ref .overwritten_forward is not None :
171
- old_forward = fn_ref .overwritten_forward
178
+ if name not in self .hooks .keys ():
179
+ logger .warning (f"hook: { name } was not found in HookRegistry" )
180
+ return
172
181
173
- if index == num_hooks - 1 :
174
- self ._module_ref .forward = old_forward
175
- else :
176
- self ._fn_refs [index + 1 ].old_forward = old_forward
177
-
178
- self ._module_ref = hook .deinitalize_hook (self ._module_ref )
179
- del self .hooks [name ]
180
- self ._hook_order .pop (index )
181
- self ._fn_refs .pop (index )
182
+ num_hooks = len (self ._hook_order )
183
+ hook = self .hooks [name ]
184
+ index = self ._hook_order .index (name )
185
+ fn_ref = self ._fn_refs [index ]
186
+
187
+ old_forward = fn_ref .forward
188
+ if fn_ref .original_forward is not None :
189
+ old_forward = fn_ref .original_forward
190
+
191
+ if index == num_hooks - 1 :
192
+ self ._module_ref .forward = old_forward
193
+ else :
194
+ self ._fn_refs [index + 1 ].forward = old_forward
195
+
196
+ self ._module_ref = hook .deinitalize_hook (self ._module_ref )
197
+ del self .hooks [name ]
198
+ self ._hook_order .pop (index )
199
+ self ._fn_refs .pop (index )
182
200
183
201
if recurse :
184
202
for module_name , module in self ._module_ref .named_modules ():
0 commit comments