
Commit 658e24e

a-r-r-o-w, oahzxl, stevhliu, and DN6 authored
[core] Pyramid Attention Broadcast (#9562)
* start pyramid attention broadcast
* add coauthor
  Co-Authored-By: Xuanlei Zhao <[email protected]>
* update
* make style
* update
* make style
* add docs
* add tests
* update
* Update docs/source/en/api/pipelines/cogvideox.md
  Co-authored-by: Steven Liu <[email protected]>
* Update docs/source/en/api/pipelines/cogvideox.md
  Co-authored-by: Steven Liu <[email protected]>
* Pyramid Attention Broadcast rewrite + introduce hooks (#9826)
* rewrite implementation with hooks
* make style
* update
* merge pyramid-attention-rewrite-2
* make style
* remove changes from latte transformer
* revert docs changes
* better debug message
* add todos for future
* update tests
* make style
* cleanup
* fix
* improve log message; fix latte test
* refactor
* update
* update
* update
* revert changes to tests
* update docs
* update tests
* Apply suggestions from code review
  Co-authored-by: Steven Liu <[email protected]>
* update
* fix flux test
* reorder
* refactor
* make fix-copies
* update docs
* fixes
* more fixes
* make style
* update tests
* update code example
* make fix-copies
* refactor based on reviews
* use maybe_free_model_hooks
* CacheMixin
* make style
* update
* add current_timestep property; update docs
* make fix-copies
* update
* improve tests
* try circular import fix
* apply suggestions from review
* address review comments
* Apply suggestions from code review
* refactor hook implementation
* add test suite for hooks
* PAB Refactor (#10667)
* update
* update
* update

---------

Co-authored-by: DN6 <[email protected]>

* update
* fix remove hook behaviour

---------

Co-authored-by: Xuanlei Zhao <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: DN6 <[email protected]>
1 parent fb42066 commit 658e24e

32 files changed: +1256 −67 lines

docs/source/en/_toctree.yml (+2)

@@ -598,6 +598,8 @@
   title: Attention Processor
 - local: api/activations
   title: Custom activation functions
+- local: api/cache
+  title: Caching methods
 - local: api/normalization
   title: Custom normalization layers
 - local: api/utilities

docs/source/en/api/cache.md (new file, +49)

@@ -0,0 +1,49 @@
+<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License. -->
+
+# Caching methods
+
+## Pyramid Attention Broadcast
+
+[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
+
+Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The difference is most prominent in the spatial attention blocks, less so in the temporal attention blocks, and smallest in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
+
+Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
+
+```python
+import torch
+from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
+
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
+# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
+# broadcast is active, leading to slower inference speeds. However, large intervals can lead to
+# poorer quality of generated videos.
+config = PyramidAttentionBroadcastConfig(
+    spatial_attention_block_skip_range=2,
+    spatial_attention_timestep_skip_range=(100, 800),
+    current_timestep_callback=lambda: pipe.current_timestep,
+)
+pipe.transformer.enable_cache(config)
+```
+
+### CacheMixin
+
+[[autodoc]] CacheMixin
+
+### PyramidAttentionBroadcastConfig
+
+[[autodoc]] PyramidAttentionBroadcastConfig
+
+[[autodoc]] apply_pyramid_attention_broadcast
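Editorial note (not part of the diff): the new file also autodocs `apply_pyramid_attention_broadcast`, the functional entry point exported alongside the config. The following is a minimal, hedged sketch of how that function relates to `enable_cache`; it assumes the function takes the denoiser module and a `PyramidAttentionBroadcastConfig`, which should be verified against the autodoc entry above.

```python
# Hedged sketch: applying PAB via the functional API instead of CacheMixin.enable_cache.
# Assumption: apply_pyramid_attention_broadcast(module, config) attaches the PAB hooks
# directly to the given denoiser module.
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
# Attach the PAB hooks to the transformer without going through pipe.transformer.enable_cache.
apply_pyramid_attention_broadcast(pipe.transformer, config)
```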

src/diffusers/__init__.py (+11)

@@ -28,6 +28,7 @@

 _import_structure = {
     "configuration_utils": ["ConfigMixin"],
+    "hooks": [],
     "loaders": ["FromOriginalModelMixin"],
     "models": [],
     "pipelines": [],
@@ -75,6 +76,13 @@
     _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]

 else:
+    _import_structure["hooks"].extend(
+        [
+            "HookRegistry",
+            "PyramidAttentionBroadcastConfig",
+            "apply_pyramid_attention_broadcast",
+        ]
+    )
     _import_structure["models"].extend(
         [
             "AllegroTransformer3DModel",
@@ -90,6 +98,7 @@
             "AutoencoderKLTemporalDecoder",
             "AutoencoderOobleck",
             "AutoencoderTiny",
+            "CacheMixin",
             "CogVideoXTransformer3DModel",
             "CogView3PlusTransformer2DModel",
             "ConsisIDTransformer3DModel",
@@ -588,6 +597,7 @@
     except OptionalDependencyNotAvailable:
         from .utils.dummy_pt_objects import *  # noqa F403
     else:
+        from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
         from .models import (
             AllegroTransformer3DModel,
             AsymmetricAutoencoderKL,
@@ -602,6 +612,7 @@
             AutoencoderKLTemporalDecoder,
             AutoencoderOobleck,
             AutoencoderTiny,
+            CacheMixin,
             CogVideoXTransformer3DModel,
             CogView3PlusTransformer2DModel,
             ConsisIDTransformer3DModel,
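Editorial note (not part of the diff): with these entries in place, the new names become importable from the top-level `diffusers` namespace. A minimal sanity-check sketch, assuming torch is installed so the lazy-import path resolves to the real objects rather than `dummy_pt_objects`:

```python
# The public names added by this diff, imported from the top-level package.
from diffusers import (
    CacheMixin,
    HookRegistry,
    PyramidAttentionBroadcastConfig,
    apply_pyramid_attention_broadcast,
)

print(PyramidAttentionBroadcastConfig)  # should resolve without ImportError
```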

src/diffusers/hooks/__init__.py (+2)

@@ -2,4 +2,6 @@


 if is_torch_available():
+    from .hooks import HookRegistry, ModelHook
     from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
+    from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

src/diffusers/hooks/hooks.py (+78 −30)

@@ -30,6 +30,9 @@ class ModelHook:

     _is_stateful = False

+    def __init__(self):
+        self.fn_ref: "HookFunctionReference" = None
+
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when a model is initialized.
@@ -48,8 +51,6 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
             module (`torch.nn.Module`):
                 The module attached to this hook.
         """
-        module.forward = module._old_forward
-        del module._old_forward
         return module

     def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
@@ -99,6 +100,29 @@ def reset_state(self, module: torch.nn.Module):
         return module


+class HookFunctionReference:
+    def __init__(self) -> None:
+        """A container class that maintains mutable references to forward pass functions in a hook chain.
+
+        Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the
+        entire forward pass structure.
+
+        Attributes:
+            pre_forward: A callable that processes inputs before the main forward pass.
+            post_forward: A callable that processes outputs after the main forward pass.
+            forward: The current forward function in the hook chain.
+            original_forward: The original forward function, stored when a hook provides a custom new_forward.
+
+        The class enables hook removal by allowing updates to the forward chain through reference modification rather
+        than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to
+        be updated, preserving the execution order of the remaining hooks.
+        """
+        self.pre_forward = None
+        self.post_forward = None
+        self.forward = None
+        self.original_forward = None
+
+
 class HookRegistry:
     def __init__(self, module_ref: torch.nn.Module) -> None:
         super().__init__()
@@ -107,51 +131,71 @@ def __init__(self, module_ref: torch.nn.Module) -> None:

         self._module_ref = module_ref
         self._hook_order = []
+        self._fn_refs = []

     def register_hook(self, hook: ModelHook, name: str) -> None:
         if name in self.hooks.keys():
-            logger.warning(f"Hook with name {name} already exists, replacing it.")
-
-        if hasattr(self._module_ref, "_old_forward"):
-            old_forward = self._module_ref._old_forward
-        else:
-            old_forward = self._module_ref.forward
-            self._module_ref._old_forward = self._module_ref.forward
+            raise ValueError(
+                f"Hook with name {name} already exists in the registry. Please use a different name or "
+                f"first remove the existing hook and then add a new one."
+            )

         self._module_ref = hook.initialize_hook(self._module_ref)

-        if hasattr(hook, "new_forward"):
-            rewritten_forward = hook.new_forward
-
+        def create_new_forward(function_reference: HookFunctionReference):
             def new_forward(module, *args, **kwargs):
-                args, kwargs = hook.pre_forward(module, *args, **kwargs)
-                output = rewritten_forward(module, *args, **kwargs)
-                return hook.post_forward(module, output)
-        else:
+                args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
+                output = function_reference.forward(*args, **kwargs)
+                return function_reference.post_forward(module, output)

-            def new_forward(module, *args, **kwargs):
-                args, kwargs = hook.pre_forward(module, *args, **kwargs)
-                output = old_forward(*args, **kwargs)
-                return hook.post_forward(module, output)
+            return new_forward
+
+        forward = self._module_ref.forward

+        fn_ref = HookFunctionReference()
+        fn_ref.pre_forward = hook.pre_forward
+        fn_ref.post_forward = hook.post_forward
+        fn_ref.forward = forward
+
+        if hasattr(hook, "new_forward"):
+            fn_ref.original_forward = forward
+            fn_ref.forward = functools.update_wrapper(
+                functools.partial(hook.new_forward, self._module_ref), hook.new_forward
+            )
+
+        rewritten_forward = create_new_forward(fn_ref)
         self._module_ref.forward = functools.update_wrapper(
-            functools.partial(new_forward, self._module_ref), old_forward
+            functools.partial(rewritten_forward, self._module_ref), rewritten_forward
         )

+        hook.fn_ref = fn_ref
         self.hooks[name] = hook
         self._hook_order.append(name)
+        self._fn_refs.append(fn_ref)

     def get_hook(self, name: str) -> Optional[ModelHook]:
-        if name not in self.hooks.keys():
-            return None
-        return self.hooks[name]
+        return self.hooks.get(name, None)

     def remove_hook(self, name: str, recurse: bool = True) -> None:
         if name in self.hooks.keys():
+            num_hooks = len(self._hook_order)
             hook = self.hooks[name]
+            index = self._hook_order.index(name)
+            fn_ref = self._fn_refs[index]
+
+            old_forward = fn_ref.forward
+            if fn_ref.original_forward is not None:
+                old_forward = fn_ref.original_forward
+
+            if index == num_hooks - 1:
+                self._module_ref.forward = old_forward
+            else:
+                self._fn_refs[index + 1].forward = old_forward
+
             self._module_ref = hook.deinitalize_hook(self._module_ref)
             del self.hooks[name]
-            self._hook_order.remove(name)
+            self._hook_order.pop(index)
+            self._fn_refs.pop(index)

         if recurse:
             for module_name, module in self._module_ref.named_modules():
@@ -161,7 +205,7 @@ def remove_hook(self, name: str, recurse: bool = True) -> None:
                     module._diffusers_hook.remove_hook(name, recurse=False)

     def reset_stateful_hooks(self, recurse: bool = True) -> None:
-        for hook_name in self._hook_order:
+        for hook_name in reversed(self._hook_order):
             hook = self.hooks[hook_name]
             if hook._is_stateful:
                 hook.reset_state(self._module_ref)
@@ -180,9 +224,13 @@ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
         return module._diffusers_hook

     def __repr__(self) -> str:
-        hook_repr = ""
+        registry_repr = ""
         for i, hook_name in enumerate(self._hook_order):
-            hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
+            if self.hooks[hook_name].__class__.__repr__ is not object.__repr__:
+                hook_repr = self.hooks[hook_name].__repr__()
+            else:
+                hook_repr = self.hooks[hook_name].__class__.__name__
+            registry_repr += f" ({i}) {hook_name} - {hook_repr}"
             if i < len(self._hook_order) - 1:
-                hook_repr += "\n"
-        return f"HookRegistry(\n{hook_repr}\n)"
+                registry_repr += "\n"
+        return f"HookRegistry(\n{registry_repr}\n)"
