Skip to content

Commit 3f3e26a

Browse files
committed
add test suite for hooks
1 parent 8c74a7a commit 3f3e26a

File tree

1 file changed

+384
-0
lines changed

1 file changed

+384
-0
lines changed

tests/hooks/test_hooks.py

+384
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,384 @@
1+
# Copyright 2024 HuggingFace Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import gc
16+
import unittest
17+
18+
import torch
19+
20+
from diffusers.hooks import HookRegistry, ModelHook
21+
from diffusers.training_utils import free_memory
22+
from diffusers.utils.logging import get_logger
23+
from diffusers.utils.testing_utils import CaptureLogger, torch_device
24+
25+
26+
logger = get_logger(__name__) # pylint: disable=invalid-name
27+
28+
29+
class DummyBlock(torch.nn.Module):
    """Minimal feed-forward block: Linear -> ReLU -> Linear."""

    def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
        super().__init__()

        self.proj_in = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        self.proj_out = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project in, apply ReLU, project out."""
        hidden = self.activation(self.proj_in(x))
        return self.proj_out(hidden)
42+
43+
44+
class DummyModel(torch.nn.Module):
    """Toy model: input projection, ReLU, a stack of DummyBlocks, output projection."""

    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
        super().__init__()

        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        self.blocks = torch.nn.ModuleList(
            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
        )
        self.linear_2 = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the input through the projection, every block in order, then the head."""
        hidden = self.activation(self.linear_1(x))
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.linear_2(hidden)
62+
63+
64+
class AddHook(ModelHook):
    """Stateless hook that adds a constant to every tensor positional argument."""

    def __init__(self, value: int):
        super().__init__()
        self.value = value  # constant added to each tensor arg

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        logger.debug("AddHook pre_forward")
        # Materialize as a tuple: a bare generator is single-use, so any second
        # iteration over args downstream would silently see it exhausted.
        args = tuple((x + self.value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def post_forward(self, module, output):
        logger.debug("AddHook post_forward")
        return output
77+
78+
79+
class MultiplyHook(ModelHook):
    """Stateless hook that multiplies every tensor positional argument by a constant."""

    def __init__(self, value: int):
        super().__init__()
        self.value = value  # constant multiplier for each tensor arg

    def pre_forward(self, module, *args, **kwargs):
        logger.debug("MultiplyHook pre_forward")
        # tuple(...) instead of a generator expression: generators are single-use,
        # so re-iterating args (e.g. by chained hooks) would see nothing.
        args = tuple((x * self.value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def post_forward(self, module, output):
        logger.debug("MultiplyHook post_forward")
        return output

    def __repr__(self):
        return f"MultiplyHook(value={self.value})"
95+
96+
97+
class StatefulAddHook(ModelHook):
    """Stateful hook: adds ``value + increment`` to tensor args, where ``increment``
    counts forward invocations since the last ``reset_state``."""

    _is_stateful = True

    def __init__(self, value: int):
        super().__init__()
        self.value = value
        self.increment = 0  # forward-call counter, cleared by reset_state()

    def pre_forward(self, module, *args, **kwargs):
        logger.debug("StatefulAddHook pre_forward")
        add_value = self.value + self.increment
        self.increment += 1
        # tuple(...) so args survive repeated iteration; a generator would be
        # exhausted after a single pass.
        args = tuple((x + add_value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def reset_state(self, module):
        self.increment = 0
114+
115+
116+
class SkipLayerHook(ModelHook):
    """Hook that can bypass the wrapped module entirely, returning its first
    positional input unchanged when ``skip_layer`` is True."""

    def __init__(self, skip_layer: bool):
        super().__init__()
        self.skip_layer = skip_layer

    def pre_forward(self, module, *args, **kwargs):
        logger.debug("SkipLayerHook pre_forward")
        return args, kwargs

    def new_forward(self, module, *args, **kwargs):
        logger.debug("SkipLayerHook new_forward")
        if not self.skip_layer:
            # Delegate to the module's original forward captured by the registry.
            return self.fn_ref.overwritten_forward(*args, **kwargs)
        # Skip: act as the identity on the first positional input.
        return args[0]

    def post_forward(self, module, output):
        logger.debug("SkipLayerHook post_forward")
        return output
134+
135+
136+
class HookTests(unittest.TestCase):
    """Tests for HookRegistry / ModelHook: registration and removal, stateful
    hook resets, inference equivalence, layer skipping, and hook invocation order."""

    # Dimensions for the DummyModel used in every test.
    in_features = 4
    hidden_features = 8
    out_features = 4
    num_layers = 2

    def setUp(self):
        # Fresh model per test so hooks registered in one test cannot leak into another.
        params = self.get_module_parameters()
        self.model = DummyModel(**params)
        self.model.to(torch_device)

    def tearDown(self):
        super().tearDown()

        # Drop the model and reclaim accelerator memory between tests.
        del self.model
        gc.collect()
        free_memory()

    def get_module_parameters(self):
        """Return the kwargs used to construct the DummyModel under test."""
        return {
            "in_features": self.in_features,
            "hidden_features": self.hidden_features,
            "out_features": self.out_features,
            "num_layers": self.num_layers,
        }

    def get_generator(self):
        # Fixed seed so test inputs are reproducible.
        return torch.manual_seed(0)

    def test_hook_registry(self):
        """Registering/removing hooks updates hooks, _hook_order, _fn_refs and repr."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")

        registry_repr = repr(registry)
        expected_repr = (
            "HookRegistry(\n" "  (0) add_hook - AddHook\n" "  (1) multiply_hook - MultiplyHook(value=2)\n" ")"
        )

        self.assertEqual(len(registry.hooks), 2)
        self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
        self.assertEqual(len(registry._fn_refs), 2)
        self.assertEqual(registry_repr, expected_repr)

        registry.remove_hook("add_hook")

        self.assertEqual(len(registry.hooks), 1)
        self.assertEqual(registry._hook_order, ["multiply_hook"])
        self.assertEqual(len(registry._fn_refs), 1)

    def test_stateful_hook(self):
        """A stateful hook's counter advances per forward and reset_stateful_hooks clears it."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(StatefulAddHook(1), "stateful_add_hook")

        self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)

        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
        num_repeats = 3

        for i in range(num_repeats):
            result = self.model(input)
            if i == 0:
                # Keep only the first forward's output; later passes see a larger increment.
                output1 = result

        self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)

        # After reset, the first forward must reproduce the original first output.
        registry.reset_stateful_hooks()
        output2 = self.model(input)

        self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
        self.assertTrue(torch.allclose(output1, output2))

    def test_inference(self):
        """Hooked model on x equals unhooked model on the pre-transformed input."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")

        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
        output1 = self.model(input).mean().detach().cpu().item()

        # Outer multiply removed: feeding input * 2 must match the fully hooked run.
        registry.remove_hook("multiply_hook")
        new_input = input * 2
        output2 = self.model(new_input).mean().detach().cpu().item()

        # Both hooks removed: feeding input * 2 + 1 must also match.
        registry.remove_hook("add_hook")
        new_input = input * 2 + 1
        output3 = self.model(new_input).mean().detach().cpu().item()

        self.assertAlmostEqual(output1, output2, places=5)
        self.assertAlmostEqual(output1, output3, places=5)

    def test_skip_layer_hook(self):
        """Skipping the whole model makes it the identity; not skipping does not."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")

        input = torch.zeros(1, 4, device=torch_device)
        output = self.model(input).mean().detach().cpu().item()
        self.assertEqual(output, 0.0)

        registry.remove_hook("skip_layer_hook")
        registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
        output = self.model(input).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

    def test_skip_layer_internal_block(self):
        """Skipping an internal layer with mismatched dims raises; skipping a
        same-dim block runs fine."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
        input = torch.zeros(1, 4, device=torch_device)

        # linear_1 maps 4 -> 8; skipping it feeds a 4-dim tensor to the 8-dim blocks.
        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
        with self.assertRaises(RuntimeError) as cm:
            self.model(input).mean().detach().cpu().item()
        self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))

        registry.remove_hook("skip_layer_hook")
        output = self.model(input).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

        # blocks[1] maps 8 -> 8, so skipping it keeps shapes consistent.
        registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
        output = self.model(input).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

    def test_invocation_order_stateful_first(self):
        """Hooks wrap LIFO: last registered runs its pre_forward first.
        StatefulAddHook logs no post_forward (it does not override it)."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(StatefulAddHook(1), "add_hook")
        registry.register_hook(AddHook(2), "add_hook_2")
        registry.register_hook(MultiplyHook(3), "multiply_hook")

        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())

        # Local logger shadows the module-level one; DEBUG so hook logs are captured.
        logger = get_logger(__name__)
        logger.setLevel("DEBUG")

        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        # Strip whitespace so the comparison is about ordering only.
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            (
                "MultiplyHook pre_forward\n"
                "AddHook pre_forward\n"
                "StatefulAddHook pre_forward\n"
                "AddHook post_forward\n"
                "MultiplyHook post_forward\n"
            )
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            (
                "MultiplyHook pre_forward\n"
                "AddHook pre_forward\n"
                "AddHook post_forward\n"
                "MultiplyHook post_forward\n"
            )
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

    def test_invocation_order_stateful_middle(self):
        """Invocation order stays consistent as hooks are removed one by one."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(2), "add_hook")
        registry.register_hook(StatefulAddHook(1), "add_hook_2")
        registry.register_hook(MultiplyHook(3), "multiply_hook")

        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())

        logger = get_logger(__name__)
        logger.setLevel("DEBUG")

        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            (
                "MultiplyHook pre_forward\n"
                "StatefulAddHook pre_forward\n"
                "AddHook pre_forward\n"
                "AddHook post_forward\n"
                "MultiplyHook post_forward\n"
            )
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            ("MultiplyHook pre_forward\nStatefulAddHook pre_forward\nMultiplyHook post_forward\n")
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook_2")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            ("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

    def test_invocation_order_stateful_last(self):
        """Stateful hook registered last wraps outermost: its pre_forward logs first."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")
        registry.register_hook(StatefulAddHook(3), "add_hook_2")

        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())

        logger = get_logger(__name__)
        logger.setLevel("DEBUG")

        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            (
                "StatefulAddHook pre_forward\n"
                "MultiplyHook pre_forward\n"
                "AddHook pre_forward\n"
                "AddHook post_forward\n"
                "MultiplyHook post_forward\n"
            )
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            ("StatefulAddHook pre_forward\nMultiplyHook pre_forward\nMultiplyHook post_forward\n")
            .replace(" ", "")
            .replace("\n", "")
        )
        self.assertEqual(output, expected_invocation_order_log)

0 commit comments

Comments
 (0)