
Commit b30cf5d

spatio temporal guidance
1 parent 357f4f0 commit b30cf5d

4 files changed (+161, -40 lines)


src/diffusers/guiders/skip_layer_guidance.py (+18 -14)
@@ -24,7 +24,8 @@

 class SkipLayerGuidance(GuidanceMixin):
     """
-    Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
+    Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG):
+    https://huggingface.co/papers/2411.18664

     SLG was introduced by StabilityAI for improving structure and anatomy coherence in generated images. It works by
     skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
@@ -36,6 +37,9 @@ class SkipLayerGuidance(GuidanceMixin):
     worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
     version of the model for the conditional prediction).

+    STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
+    generation quality in video diffusion models.
+
     Additional reading:
     - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
@@ -54,13 +58,13 @@ class SkipLayerGuidance(GuidanceMixin):
            The fraction of the total number of denoising steps after which skip layer guidance starts.
        skip_layer_guidance_stop (`float`, defaults to `0.2`):
            The fraction of the total number of denoising steps after which skip layer guidance stops.
-       skip_guidance_layers (`int` or `List[int]`, *optional*):
+       skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
            The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
            provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
            3.5 Medium.
        skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
            The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
-           `LayerSkipConfig`. If not provided, `skip_guidance_layers` must be provided.
+           `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
@@ -79,7 +83,7 @@ def __init__(
         skip_layer_guidance_scale: float = 2.8,
         skip_layer_guidance_start: float = 0.01,
         skip_layer_guidance_stop: float = 0.2,
-        skip_guidance_layers: Optional[Union[int, List[int]]] = None,
+        skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
         skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
         guidance_rescale: float = 0.0,
         use_original_formulation: bool = False,
@@ -102,21 +106,21 @@ def __init__(
                 f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
             )

-        if skip_guidance_layers is None and skip_layer_config is None:
+        if skip_layer_guidance_layers is None and skip_layer_config is None:
             raise ValueError(
-                "Either `skip_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
+                "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
             )
-        if skip_guidance_layers is not None and skip_layer_config is not None:
-            raise ValueError("Only one of `skip_guidance_layers` or `skip_layer_config` can be provided.")
+        if skip_layer_guidance_layers is not None and skip_layer_config is not None:
+            raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")

-        if skip_guidance_layers is not None:
-            if isinstance(skip_guidance_layers, int):
-                skip_guidance_layers = [skip_guidance_layers]
-            if not isinstance(skip_guidance_layers, list):
+        if skip_layer_guidance_layers is not None:
+            if isinstance(skip_layer_guidance_layers, int):
+                skip_layer_guidance_layers = [skip_layer_guidance_layers]
+            if not isinstance(skip_layer_guidance_layers, list):
                 raise ValueError(
-                    f"Expected `skip_guidance_layers` to be an int or a list of ints, but got {type(skip_guidance_layers)}."
+                    f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
                 )
-            skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_guidance_layers]
+            skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]

         if isinstance(skip_layer_config, LayerSkipConfig):
             skip_layer_config = [skip_layer_config]
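For reference on the rename, a minimal construction sketch. It assumes `SkipLayerGuidance` is importable from `diffusers.guiders.skip_layer_guidance` and that any constructor arguments not shown in this diff keep their defaults; the snippet is not part of the commit.

```python
from diffusers.guiders.skip_layer_guidance import SkipLayerGuidance

# Layer indices are now passed as `skip_layer_guidance_layers` (previously
# `skip_guidance_layers`); internally each index becomes LayerSkipConfig(layer, fqn="auto").
guidance = SkipLayerGuidance(
    skip_layer_guidance_scale=2.8,
    skip_layer_guidance_start=0.01,
    skip_layer_guidance_stop=0.2,
    skip_layer_guidance_layers=[7, 8, 9],  # values recommended in the docstring for SD3.5 Medium
)
```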

src/diffusers/hooks/_common.py (+2)
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from ..models.attention import FeedForward, LuminaFeedForward
 from ..models.attention_processor import Attention, MochiAttention


 _ATTENTION_CLASSES = (Attention, MochiAttention)
+_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)

 _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
 _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
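`_FEEDFORWARD_CLASSES` mirrors `_ATTENTION_CLASSES`: a tuple of module types so that later hooks can find their targets with a single `isinstance` check. A toy illustration of that pattern (using `torch.nn.Linear` as a stand-in, since the real tuple holds diffusers' `FeedForward` classes):

```python
import torch

FEEDFORWARD_CLASSES = (torch.nn.Linear,)  # stand-in for (FeedForward, LuminaFeedForward)

block = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8))

# Same named_modules() + isinstance walk that layer_skip.py uses to attach hooks.
for name, submodule in block.named_modules():
    if isinstance(submodule, FEEDFORWARD_CLASSES):
        print(f"would attach a FeedForwardSkipHook to '{name}'")
```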

src/diffusers/hooks/_helpers.py (+61 -18)
@@ -33,32 +33,37 @@
 from ..models.transformers.transformer_wan import WanTransformerBlock


+@dataclass
+class AttentionProcessorMetadata:
+    skip_processor_output_fn: Callable[[Any], Any]
+
+
+@dataclass
+class GuidanceMetadata:
+    perturbed_attention_guidance_processor_cls: Type = None
+
+
 @dataclass
 class TransformerBlockMetadata:
     skip_block_output_fn: Callable[[Any], Any]
     return_hidden_states_index: int = None
     return_encoder_hidden_states_index: int = None


-class TransformerBlockRegistry:
+class AttentionProcessorRegistry:
     _registry = {}

     @classmethod
-    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
+    def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
         cls._registry[model_class] = metadata

     @classmethod
-    def get(cls, model_class: Type) -> TransformerBlockMetadata:
+    def get(cls, model_class: Type) -> AttentionProcessorMetadata:
         if model_class not in cls._registry:
             raise ValueError(f"Model class {model_class} not registered.")
         return cls._registry[model_class]


-@dataclass
-class GuidanceMetadata:
-    perturbed_attention_guidance_processor_cls: Type = None
-
-
 class GuidanceMetadataRegistry:
     _registry = {}

@@ -73,6 +78,40 @@ def get(cls, model_class: Type) -> GuidanceMetadata:
         return cls._registry[model_class]


+class TransformerBlockRegistry:
+    _registry = {}
+
+    @classmethod
+    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
+        cls._registry[model_class] = metadata
+
+    @classmethod
+    def get(cls, model_class: Type) -> TransformerBlockMetadata:
+        if model_class not in cls._registry:
+            raise ValueError(f"Model class {model_class} not registered.")
+        return cls._registry[model_class]
+
+
+def _register_attention_processors_metadata():
+    # CogView4
+    AttentionProcessorRegistry.register(
+        model_class=CogView4AttnProcessor,
+        metadata=AttentionProcessorMetadata(
+            skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
+        ),
+    )
+
+
+def _register_guidance_metadata():
+    # CogView4
+    GuidanceMetadataRegistry.register(
+        model_class=CogView4AttnProcessor,
+        metadata=GuidanceMetadata(
+            perturbed_attention_guidance_processor_cls=CogView4PAGAttnProcessor,
+        ),
+    )
+
+
 def _register_transformer_blocks_metadata():
     # CogVideoX
     TransformerBlockRegistry.register(
@@ -177,17 +216,20 @@ def _register_transformer_blocks_metadata():
     )


-def _register_guidance_metadata():
-    # CogView4
-    GuidanceMetadataRegistry.register(
-        model_class=CogView4AttnProcessor,
-        metadata=GuidanceMetadata(
-            perturbed_attention_guidance_processor_cls=CogView4PAGAttnProcessor,
-        ),
-    )
+# fmt: off
+def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
+    hidden_states = kwargs.get("hidden_states", None)
+    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+    if hidden_states is None and len(args) > 0:
+        hidden_states = args[0]
+    if encoder_hidden_states is None and len(args) > 1:
+        encoder_hidden_states = args[1]
+    return hidden_states, encoder_hidden_states
+
+
+_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states


-# fmt: off
 def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
     hidden_states = kwargs.get("hidden_states", None)
     if hidden_states is None and len(args) > 0:
@@ -229,5 +271,6 @@ def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___en
 # fmt: on


-_register_transformer_blocks_metadata()
+_register_attention_processors_metadata()
 _register_guidance_metadata()
+_register_transformer_blocks_metadata()
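The oddly named `_skip_attention___ret___hidden_states___encoder_hidden_states` encodes its behaviour in its name: when a processor is skipped, the call simply returns the `hidden_states` / `encoder_hidden_states` it received, so the block behaves as if attention contributed nothing. A toy stand-in (not the diffusers function itself) showing that effect:

```python
import torch


def skip_attention_return_inputs(module, *args, **kwargs):
    # Mirrors the registered skip function: pass the inputs straight through.
    hidden_states = kwargs.get("hidden_states", args[0] if len(args) > 0 else None)
    encoder_hidden_states = kwargs.get("encoder_hidden_states", args[1] if len(args) > 1 else None)
    return hidden_states, encoder_hidden_states


h = torch.randn(1, 16, 64)
enc = torch.randn(1, 8, 64)
out_h, out_enc = skip_attention_return_inputs(None, hidden_states=h, encoder_hidden_states=enc)
assert out_h is h and out_enc is enc  # the "skipped" processor returns its inputs untouched
```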

src/diffusers/hooks/layer_skip.py (+80 -8)
@@ -13,14 +13,14 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Callable, List, Optional

 import torch

 from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
-from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
-from ._helpers import TransformerBlockRegistry
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
+from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
 from .hooks import HookRegistry, ModelHook

@@ -44,9 +44,50 @@ class LayerSkipConfig:

     indices: List[int]
     fqn: str = "auto"
+    skip_attention: bool = True
+    skip_attention_scores: bool = False
+    skip_ff: bool = True


-class LayerSkipHook(ModelHook):
+class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        if func is torch.nn.functional.scaled_dot_product_attention:
+            value = kwargs.get("value", None)
+            if value is None:
+                value = args[2]
+            return value
+        return func(*args, **kwargs)
+
+
+class AttentionProcessorSkipHook(ModelHook):
+    def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False):
+        self.skip_processor_output_fn = skip_processor_output_fn
+        self.skip_attention_scores = skip_attention_scores
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        if self.skip_attention_scores:
+            with AttentionScoreSkipFunctionMode():
+                return self.fn_ref.original_forward(*args, **kwargs)
+        else:
+            return self.skip_processor_output_fn(module, *args, **kwargs)
+
+
+class FeedForwardSkipHook(ModelHook):
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        output = kwargs.get("hidden_states", None)
+        if output is None:
+            output = kwargs.get("x", None)
+        if output is None and len(args) > 0:
+            output = args[0]
+        return output
+
+
+class TransformerBlockSkipHook(ModelHook):
     def initialize_hook(self, module):
         self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
         return module
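`AttentionScoreSkipFunctionMode` uses `torch.overrides.TorchFunctionMode` to intercept `scaled_dot_product_attention` and return the `value` tensor directly, which is what lets `skip_attention_scores` work without modifying any processor code. A standalone illustration of that interception (toy tensors, no diffusers dependency):

```python
import torch
import torch.nn.functional as F


class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
    # Same idea as the mode added in this commit: any SDPA call made while the
    # mode is active short-circuits and returns `value` unchanged.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.scaled_dot_product_attention:
            return kwargs.get("value", args[2] if len(args) > 2 else None)
        return func(*args, **kwargs)


query = key = torch.randn(1, 4, 16, 8)
value = torch.randn(1, 4, 16, 8)

with AttentionScoreSkipFunctionMode():
    out = F.scaled_dot_product_attention(query, key, value)

# With the mode active, softmax(QK^T)V is never computed; `value` comes back as-is.
assert torch.equal(out, value)
```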
@@ -81,6 +122,9 @@ def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
 def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
     name = name or _LAYER_SKIP_HOOK

+    if config.skip_attention and config.skip_attention_scores:
+        raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
+
     if config.fqn == "auto":
         for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
             if hasattr(module, identifier):
@@ -101,10 +145,38 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
     if len(config.indices) == 0:
         raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")

+    blocks_found = False
     for i, block in enumerate(transformer_blocks):
         if i not in config.indices:
             continue
-        logger.debug(f"Apply LayerSkipHook to '{config.fqn}.{i}'")
-        registry = HookRegistry.check_if_exists_or_initialize(block)
-        hook = LayerSkipHook()
-        registry.register_hook(hook, name)
+        blocks_found = True
+        if config.skip_attention and config.skip_ff:
+            logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
+            registry = HookRegistry.check_if_exists_or_initialize(block)
+            hook = TransformerBlockSkipHook()
+            registry.register_hook(hook, name)
+        elif config.skip_attention or config.skip_attention_scores:
+            for submodule_name, submodule in block.named_modules():
+                if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
+                    logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
+                    output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
+                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
+                    hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores)
+                    registry.register_hook(hook, name)
+        elif config.skip_ff:
+            for submodule_name, submodule in block.named_modules():
+                if isinstance(submodule, _FEEDFORWARD_CLASSES):
+                    logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
+                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
+                    hook = FeedForwardSkipHook()
+                    registry.register_hook(hook, name)
+        else:
+            raise ValueError(
+                "At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True."
+            )
+
+    if not blocks_found:
+        raise ValueError(
+            f"Could not find any transformer blocks matching the provided indices {config.indices} and "
+            f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
+        )
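Putting the new config flags and hook wiring together, a usage sketch; `transformer` below is a placeholder for an already-loaded diffusers transformer model, and the snippet is not part of the commit:

```python
from diffusers.hooks.layer_skip import LayerSkipConfig, apply_layer_skip

# Skip only the attention-score computation in blocks 7-9: SDPA collapses to returning V
# via AttentionScoreSkipFunctionMode, while the feed-forward path runs normally.
# `skip_attention` must be False here, since combining it with `skip_attention_scores`
# raises a ValueError in _apply_layer_skip_hook.
config = LayerSkipConfig(
    indices=[7, 8, 9],
    fqn="auto",
    skip_attention=False,
    skip_attention_scores=True,
    skip_ff=False,
)
apply_layer_skip(transformer, config)  # `transformer`: placeholder for a loaded model
```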
