From b79928d8e0266e6a94add3ecc4828c039a6db614 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 20 Jan 2025 13:29:15 +0100 Subject: [PATCH 1/7] update --- src/diffusers/models/modeling_utils.py | 62 +++++++++++++++++-- .../models/transformers/transformer_ltx.py | 12 +--- 2 files changed, 61 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 1c2b9a76dd67..2cc524d21318 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -27,6 +27,7 @@ import safetensors import torch +import torch.utils.checkpoint from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards from huggingface_hub.utils import validate_hf_hub_args from torch import Tensor, nn @@ -154,6 +155,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): def __init__(self): super().__init__() + self._gradient_checkpointing_func = None + def __getattr__(self, name: str) -> Any: """The only reason we overwrite `getattr` here is to gracefully deprecate accessing config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite @@ -179,14 +182,47 @@ def is_gradient_checkpointing(self) -> bool: """ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) - def enable_gradient_checkpointing(self) -> None: + def enable_gradient_checkpointing( + self, + gradient_checkpointing_func: Optional[Callable] = None, + gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: """ Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or *checkpoint activations* in other frameworks). """ if not self._supports_gradient_checkpointing: - raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") - self.apply(partial(self._set_gradient_checkpointing, value=True)) + raise ValueError( + f"{self.__class__.__name__} does not support gradient checkpointing. Please make sure to set the boolean attribute " + f"`_supports_gradient_checkpointing` to `True` in the class definition." + ) + + user_provided_gradient_checkpointing_func = gradient_checkpointing_func is not None + if gradient_checkpointing_func is None: + + def _gradient_checkpointing_func(module, *args): + ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + return torch.utils.checkpoint.checkpoint( + module.__call__, + *args, + **ckpt_kwargs, + ) + + gradient_checkpointing_func = _gradient_checkpointing_func + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} + + if ( + not user_provided_gradient_checkpointing_func + and is_torch_version(">=", "1.11.0") + and inspect.signature(gradient_checkpointing_func).parameters.get("use_reentrant") is not None + ): + gradient_checkpointing_kwargs["use_reentrant"] = False + + gradient_checkpointing_func = partial(gradient_checkpointing_func, **gradient_checkpointing_kwargs) + + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) def disable_gradient_checkpointing(self) -> None: """ @@ -194,7 +230,7 @@ def disable_gradient_checkpointing(self) -> None: *checkpoint activations* in other frameworks). 
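A minimal sketch of the call pattern this hunk is meant to support, assuming a hypothetical `MyTransformer` that subclasses `ModelMixin` and sets `_supports_gradient_checkpointing = True`; the custom function below simply forwards to the stock PyTorch checkpoint:

```python
from functools import partial

import torch.utils.checkpoint

model = MyTransformer()  # hypothetical ModelMixin subclass

# Default path: wraps torch.utils.checkpoint.checkpoint, non-reentrant on torch >= 1.11.
model.enable_gradient_checkpointing()

# Custom path: any callable invoked as func(module, *args) works, e.g. the stock
# PyTorch checkpoint with an explicit use_reentrant setting.
model.enable_gradient_checkpointing(
    gradient_checkpointing_func=partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
)

model.disable_gradient_checkpointing()
```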
""" if self._supports_gradient_checkpointing: - self.apply(partial(self._set_gradient_checkpointing, value=False)) + self._set_gradient_checkpointing(enable=False) def set_use_npu_flash_attention(self, valid: bool) -> None: r""" @@ -1354,6 +1390,24 @@ def get_memory_footprint(self, return_buffers=True): mem = mem + mem_bufs return mem + def _set_gradient_checkpointing( + self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint + ) -> None: + is_gradient_checkpointing_set = False + + for name, module in self.named_modules(): + if hasattr(module, "gradient_checkpointing"): + logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'") + module._gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + if not is_gradient_checkpointing_set: + raise ValueError( + f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to use a module that supports gradient checkpointing " + f"by creating a boolean attribute `gradient_checkpointing` in the module and setting it to `True`." + ) + def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: deprecated_attention_block_paths = [] diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index a895340bd124..3b83e186d0d8 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import Attention @@ -360,10 +360,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -426,15 +422,13 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, encoder_attention_mask, - **ckpt_kwargs, ) else: hidden_states = block( From 325e7408cf0c7b1c8e684d1082128beed49b71ce Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 20 Jan 2025 13:45:47 +0100 Subject: [PATCH 2/7] remove unused fn --- src/diffusers/models/transformers/transformer_ltx.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 3b83e186d0d8..c037b3e9447d 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -412,16 +412,6 @@ def forward( for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is 
not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - hidden_states = self._gradient_checkpointing_func( block, hidden_states, From 3f7aa5337de9024e7b56e1d0c2b914d5cd47f6f4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 22 Jan 2025 19:24:44 +0100 Subject: [PATCH 3/7] apply suggestions based on review --- src/diffusers/models/modeling_utils.py | 30 ++++++++------------------ 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 4f73c18f11a3..3ef40ffb5783 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -21,7 +21,7 @@ import os import re from collections import OrderedDict -from functools import partial, wraps +from functools import wraps from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union @@ -196,14 +196,15 @@ def is_gradient_checkpointing(self) -> bool: """ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) - def enable_gradient_checkpointing( - self, - gradient_checkpointing_func: Optional[Callable] = None, - gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None, - ) -> None: + def enable_gradient_checkpointing(self, gradient_checkpointing_func: Optional[Callable] = None) -> None: """ Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or *checkpoint activations* in other frameworks). + + Args: + gradient_checkpointing_func (`Callable`, *optional*): + The function to use for gradient checkpointing. If `None`, the default PyTorch checkpointing function + is used (`torch.utils.checkpoint.checkpoint`). """ if not self._supports_gradient_checkpointing: raise ValueError( @@ -211,7 +212,6 @@ def enable_gradient_checkpointing( f"`_supports_gradient_checkpointing` to `True` in the class definition." ) - user_provided_gradient_checkpointing_func = gradient_checkpointing_func is not None if gradient_checkpointing_func is None: def _gradient_checkpointing_func(module, *args): @@ -224,18 +224,6 @@ def _gradient_checkpointing_func(module, *args): gradient_checkpointing_func = _gradient_checkpointing_func - if gradient_checkpointing_kwargs is None: - gradient_checkpointing_kwargs = {} - - if ( - not user_provided_gradient_checkpointing_func - and is_torch_version(">=", "1.11.0") - and inspect.signature(gradient_checkpointing_func).parameters.get("use_reentrant") is not None - ): - gradient_checkpointing_kwargs["use_reentrant"] = False - - gradient_checkpointing_func = partial(gradient_checkpointing_func, **gradient_checkpointing_kwargs) - self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) def disable_gradient_checkpointing(self) -> None: @@ -1502,8 +1490,8 @@ def _set_gradient_checkpointing( if not is_gradient_checkpointing_set: raise ValueError( - f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to use a module that supports gradient checkpointing " - f"by creating a boolean attribute `gradient_checkpointing` in the module and setting it to `True`." + f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to " + f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`." 
) def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: From d0c3aaeec47b0a725b3f8a44a661e548e24c1f87 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 25 Jan 2025 17:35:07 +0100 Subject: [PATCH 4/7] =?UTF-8?q?update=20+=20cleanup=20=F0=9F=A7=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../models/autoencoders/autoencoder_kl.py | 4 - .../autoencoders/autoencoder_kl_allegro.py | 26 +- .../autoencoders/autoencoder_kl_cogvideox.py | 67 +---- .../autoencoder_kl_hunyuan_video.py | 100 +------ .../models/autoencoders/autoencoder_kl_ltx.py | 61 +--- .../autoencoders/autoencoder_kl_mochi.py | 67 +---- .../autoencoder_kl_temporal_decoder.py | 53 +--- .../models/autoencoders/autoencoder_tiny.py | 4 - src/diffusers/models/autoencoders/vae.py | 172 +++--------- .../models/controlnets/controlnet.py | 6 - .../models/controlnets/controlnet_flux.py | 38 +-- .../models/controlnets/controlnet_sd3.py | 26 +- .../controlnets/controlnet_sparsectrl.py | 4 - .../models/controlnets/controlnet_union.py | 6 - .../models/controlnets/controlnet_xs.py | 48 +--- .../transformers/auraflow_transformer_2d.py | 40 +-- .../transformers/cogvideox_transformer_3d.py | 18 +- .../transformers/consisid_transformer_3d.py | 20 +- .../models/transformers/dit_transformer_2d.py | 22 +- .../transformers/latte_transformer_3d.py | 9 +- .../transformers/pixart_transformer_2d.py | 22 +- .../models/transformers/sana_transformer.py | 23 +- .../transformers/stable_audio_transformer.py | 24 +- .../models/transformers/transformer_2d.py | 22 +- .../transformers/transformer_allegro.py | 20 +- .../transformers/transformer_cogview3plus.py | 21 +- .../models/transformers/transformer_flux.py | 38 +-- .../transformers/transformer_hunyuan_video.py | 28 +- .../models/transformers/transformer_mochi.py | 19 +- .../models/transformers/transformer_sd3.py | 22 +- .../transformers/transformer_temporal.py | 14 +- src/diffusers/models/unets/unet_2d.py | 4 - src/diffusers/models/unets/unet_2d_blocks.py | 260 ++---------------- .../models/unets/unet_2d_condition.py | 4 - src/diffusers/models/unets/unet_3d_blocks.py | 141 +--------- .../models/unets/unet_3d_condition.py | 8 - src/diffusers/models/unets/unet_i2vgen_xl.py | 9 - src/diffusers/models/unets/unet_kandinsky3.py | 4 - .../models/unets/unet_motion_model.py | 121 +------- .../unets/unet_spatio_temporal_condition.py | 4 - .../models/unets/unet_stable_cascade.py | 43 +-- src/diffusers/models/unets/uvit_2d.py | 3 - .../pipelines/audioldm2/modeling_audioldm2.py | 76 +---- .../blip_diffusion/modeling_blip2.py | 13 +- .../versatile_diffusion/modeling_text_unet.py | 110 +------- .../pipelines/kolors/text_encoder.py | 6 +- .../pipeline_latent_diffusion.py | 15 +- .../wuerstchen/modeling_wuerstchen_prior.py | 38 +-- 48 files changed, 246 insertions(+), 1657 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 9036c027a535..357df0c31087 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -138,10 +138,6 @@ def __init__( self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_overlap_factor = 0.25 - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (Encoder, Decoder)): - module.gradient_checkpointing = value - def enable_tiling(self, use_tiling: bool = True): r""" Enable tiled VAE decoding. 
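The per-model `_set_gradient_checkpointing` overrides deleted throughout this patch are superseded by `ModelMixin._set_gradient_checkpointing`, which walks `named_modules()`, flips every boolean `gradient_checkpointing` attribute it finds, and attaches `_gradient_checkpointing_func`. A condensed sketch of the contract a model now has to satisfy, using a hypothetical `ToyModel` (real diffusers models additionally inherit `ConfigMixin`):

```python
import torch
import torch.nn as nn

from diffusers.models.modeling_utils import ModelMixin


class ToyModel(ModelMixin):
    # Required so enable_gradient_checkpointing() does not raise.
    _supports_gradient_checkpointing = True

    def __init__(self, dim: int = 32, num_layers: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        # ModelMixin._set_gradient_checkpointing toggles this flag and attaches
        # self._gradient_checkpointing_func to every module that defines it.
        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(block, hidden_states)
            else:
                hidden_states = block(hidden_states)
        return hidden_states


model = ToyModel()
model.enable_gradient_checkpointing()
out = model(torch.randn(2, 32, requires_grad=True))
out.sum().backward()
```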
When this option is enabled, the VAE will split the input tensor into tiles to diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index b62ed67ade29..f79aabe91dd3 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -507,19 +507,12 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: sample = sample + residual if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - # Down blocks for down_block in self.down_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) + sample = self._gradient_checkpointing_func(down_block, sample) # Mid block - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + sample = self._gradient_checkpointing_func(self.mid_block, sample) else: # Down blocks for down_block in self.down_blocks: @@ -647,19 +640,12 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: upscale_dtype = next(iter(self.up_blocks.parameters())).dtype if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - # Mid block - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + sample = self._gradient_checkpointing_func(self.mid_block, sample) # Up blocks for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) + sample = self._gradient_checkpointing_func(up_block, sample) else: # Mid block @@ -809,10 +795,6 @@ def __init__( sample_size - self.tile_overlap_w, ) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)): - module.gradient_checkpointing = value - def enable_tiling(self) -> None: r""" Enable tiled VAE decoding. 
When this option is enabled, the VAE will split the input tensor into tiles to diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 941b3eb07f10..829e0fe54dd2 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -421,15 +421,8 @@ def forward( conv_cache_key = f"resnet_{i}" if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, hidden_states, temb, zq, @@ -523,15 +516,8 @@ def forward( conv_cache_key = f"resnet_{i}" if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key) + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, hidden_states, temb, zq, conv_cache.get(conv_cache_key) ) else: hidden_states, new_conv_cache[conv_cache_key] = resnet( @@ -637,15 +623,8 @@ def forward( conv_cache_key = f"resnet_{i}" if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, hidden_states, temb, zq, @@ -774,18 +753,11 @@ def forward( hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - # 1. Down for i, down_block in enumerate(self.down_blocks): conv_cache_key = f"down_block_{i}" - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + down_block, hidden_states, temb, None, @@ -793,8 +765,8 @@ def custom_forward(*inputs): ) # 2. Mid - hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), + hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func( + self.mid_block, hidden_states, temb, None, @@ -940,16 +912,9 @@ def forward( hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - # 1. 
Mid - hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), + hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func( + self.mid_block, hidden_states, temb, sample, @@ -959,8 +924,8 @@ def custom_forward(*inputs): # 2. Up for i, up_block in enumerate(self.up_blocks): conv_cache_key = f"up_block_{i}" - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + up_block, hidden_states, temb, sample, @@ -1122,10 +1087,6 @@ def __init__( self.tile_overlap_factor_height = 1 / 6 self.tile_overlap_factor_width = 1 / 5 - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)): - module.gradient_checkpointing = value - def enable_tiling( self, tile_sample_min_height: Optional[int] = None, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index e2236a7f20ad..df2a6eb0d319 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import torch @@ -21,7 +21,7 @@ import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging +from ...utils import logging from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation from ..attention_processor import Attention @@ -252,21 +252,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(self.resnets[0], hidden_states) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: @@ -278,9 +264,7 @@ def custom_forward(*inputs): hidden_states = attn(hidden_states, attention_mask=attention_mask) hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) else: hidden_states = self.resnets[0](hidden_states) @@ -350,22 +334,8 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - 
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - for resnet in self.resnets: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) else: for resnet in self.resnets: hidden_states = resnet(hidden_states) @@ -426,22 +396,8 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - for resnet in self.resnets: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) else: for resnet in self.resnets: @@ -545,26 +501,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - for down_block in self.down_blocks: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) else: for down_block in self.down_blocks: hidden_states = down_block(hidden_states) @@ -667,26 +607,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) for up_block in self.up_blocks: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), hidden_states, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states) else: hidden_states = self.mid_block(hidden_states) @@ -800,10 +724,6 @@ def __init__( self.tile_sample_stride_width = 192 self.tile_sample_stride_num_frames = 12 - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)): - module.gradient_checkpointing = value - def enable_tiling( self, tile_sample_min_height: 
Optional[int] = None, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 25753afd5ce6..75709ca10dfe 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -338,16 +338,7 @@ def forward( for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, generator - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) else: hidden_states = resnet(hidden_states, temb, generator) @@ -438,16 +429,7 @@ def forward( for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, generator - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) else: hidden_states = resnet(hidden_states, temb, generator) @@ -573,16 +555,7 @@ def forward( for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, generator - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) else: hidden_states = resnet(hidden_states, temb, generator) @@ -697,17 +670,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - for down_block in self.down_blocks: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), hidden_states) + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) else: for down_block in self.down_blocks: hidden_states = down_block(hidden_states) @@ -838,19 +804,10 @@ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = No hidden_states = self.conv_in(hidden_states) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb) for up_block in self.up_blocks: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb) else: hidden_states = self.mid_block(hidden_states, temb) @@ -1017,10 +974,6 @@ 
def __init__( self.tile_sample_stride_width = 448 self.tile_sample_stride_num_frames = 8 - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)): - module.gradient_checkpointing = value - def enable_tiling( self, tile_sample_min_height: Optional[int] = None, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py index 920b0b62fef6..cd3eff73ed64 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -207,15 +207,8 @@ def forward( conv_cache_key = f"resnet_{i}" if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key), ) @@ -312,15 +305,8 @@ def forward( conv_cache_key = f"resnet_{i}" if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, conv_cache=conv_cache.get(conv_cache_key) + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key) ) else: hidden_states, new_conv_cache[conv_cache_key] = resnet( @@ -393,15 +379,8 @@ def forward( conv_cache_key = f"resnet_{i}" if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key), ) @@ -531,21 +510,14 @@ def forward( hidden_states = hidden_states.permute(0, 4, 1, 2, 3) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in") + hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func( + self.block_in, hidden_states, conv_cache=conv_cache.get("block_in") ) for i, down_block in enumerate(self.down_blocks): conv_cache_key = f"down_block_{i}" - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key) + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + down_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key) ) else: hidden_states, new_conv_cache["block_in"] = self.block_in( @@ -648,21 +620,14 @@ def forward( # 1. 
Mid if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def create_forward(*inputs): - return module(*inputs) - - return create_forward - - hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in") + hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func( + self.block_in, hidden_states, conv_cache=conv_cache.get("block_in") ) for i, up_block in enumerate(self.up_blocks): conv_cache_key = f"up_block_{i}" - hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key) + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + up_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key) ) else: hidden_states, new_conv_cache["block_in"] = self.block_in( @@ -819,10 +784,6 @@ def __init__( self.tile_sample_stride_height = 192 self.tile_sample_stride_width = 192 - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (MochiEncoder3D, MochiDecoder3D)): - module.gradient_checkpointing = value - def enable_tiling( self, tile_sample_min_height: Optional[int] = None, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py index 38ad78c0707b..5a72cd395196 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py @@ -18,7 +18,6 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version from ...utils.accelerate_utils import apply_forward_hook from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor from ..modeling_outputs import AutoencoderKLOutput @@ -97,47 +96,21 @@ def forward( upscale_dtype = next(itertools.chain(self.up_blocks.parameters(), self.up_blocks.buffers())).dtype if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + sample = self._gradient_checkpointing_func( + self.mid_block, + sample, + image_only_indicator, + ) + sample = sample.to(upscale_dtype) - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), - sample, - image_only_indicator, - use_reentrant=False, - ) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), - sample, - image_only_indicator, - use_reentrant=False, - ) - else: - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), + # up + for up_block in self.up_blocks: + sample = self._gradient_checkpointing_func( + up_block, sample, image_only_indicator, ) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), - sample, - image_only_indicator, - ) else: # middle sample = self.mid_block(sample, image_only_indicator=image_only_indicator) @@ -229,10 +202,6 @@ def __init__( self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) - def 
_set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (Encoder, TemporalDecoder)): - module.gradient_checkpointing = value - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: diff --git a/src/diffusers/models/autoencoders/autoencoder_tiny.py b/src/diffusers/models/autoencoders/autoencoder_tiny.py index 35081c22dfc4..7ed727c55c37 100644 --- a/src/diffusers/models/autoencoders/autoencoder_tiny.py +++ b/src/diffusers/models/autoencoders/autoencoder_tiny.py @@ -154,10 +154,6 @@ def __init__( self.register_to_config(block_out_channels=decoder_block_out_channels) self.register_to_config(force_upcast=False) - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (EncoderTiny, DecoderTiny)): - module.gradient_checkpointing = value - def scale_latents(self, x: torch.Tensor) -> torch.Tensor: """raw latents -> [0, 1]""" return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1) diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py index 7fc7d5a4d797..72e0acda3afe 100644 --- a/src/diffusers/models/autoencoders/vae.py +++ b/src/diffusers/models/autoencoders/vae.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn -from ...utils import BaseOutput, is_torch_version +from ...utils import BaseOutput from ...utils.torch_utils import randn_tensor from ..activations import get_activation from ..attention_processor import SpatialNorm @@ -156,28 +156,11 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: sample = self.conv_in(sample) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - # down - if is_torch_version(">=", "1.11.0"): - for down_block in self.down_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), sample, use_reentrant=False - ) - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), sample, use_reentrant=False - ) - else: - for down_block in self.down_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) - # middle - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + for down_block in self.down_blocks: + sample = self._gradient_checkpointing_func(down_block, sample) + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample) else: # down @@ -305,41 +288,13 @@ def forward( upscale_dtype = next(iter(self.up_blocks.parameters())).dtype if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) + sample = sample.to(upscale_dtype) - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), - sample, - latent_embeds, - use_reentrant=False, - ) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), - sample, - latent_embeds, - use_reentrant=False, - ) - else: - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), 
sample, latent_embeds - ) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) + # up + for up_block in self.up_blocks: + sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds) else: # middle sample = self.mid_block(sample, latent_embeds) @@ -558,72 +513,28 @@ def forward( upscale_dtype = next(iter(self.up_blocks.parameters())).dtype if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) + sample = sample.to(upscale_dtype) - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), - sample, - latent_embeds, - use_reentrant=False, - ) - sample = sample.to(upscale_dtype) - - # condition encoder - if image is not None and mask is not None: - masked_image = (1 - mask) * image - im_x = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.condition_encoder), - masked_image, - mask, - use_reentrant=False, - ) - - # up - for up_block in self.up_blocks: - if image is not None and mask is not None: - sample_ = im_x[str(tuple(sample.shape))] - mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") - sample = sample * mask_ + sample_ * (1 - mask_) - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), - sample, - latent_embeds, - use_reentrant=False, - ) - if image is not None and mask is not None: - sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) - else: - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), sample, latent_embeds + # condition encoder + if image is not None and mask is not None: + masked_image = (1 - mask) * image + im_x = self._gradient_checkpointing_func( + self.condition_encoder, + masked_image, + mask, ) - sample = sample.to(upscale_dtype) - # condition encoder - if image is not None and mask is not None: - masked_image = (1 - mask) * image - im_x = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.condition_encoder), - masked_image, - mask, - ) - - # up - for up_block in self.up_blocks: - if image is not None and mask is not None: - sample_ = im_x[str(tuple(sample.shape))] - mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") - sample = sample * mask_ + sample_ * (1 - mask_) - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) + # up + for up_block in self.up_blocks: if image is not None and mask is not None: - sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + sample_ = im_x[str(tuple(sample.shape))] + mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) else: # middle sample = self.mid_block(sample, latent_embeds) @@ -890,17 +801,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: r"""The forward method of the `EncoderTiny` class.""" if torch.is_grad_enabled() and self.gradient_checkpointing: - - def 
create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False) - else: - x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x) + x = self._gradient_checkpointing_func(self.layers, x) else: # scale image from [-1, 1] to [0, 1] to match TAESD convention @@ -976,18 +877,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = torch.tanh(x / 3) * 3 if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False) - else: - x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x) - + x = self._gradient_checkpointing_func(self.layers, x) else: x = self.layers(x) diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index 1453aaf4362c..7a6ca886caed 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -31,8 +31,6 @@ from ..embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin from ..unets.unet_2d_blocks import ( - CrossAttnDownBlock2D, - DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block, @@ -659,10 +657,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): - module.gradient_checkpointing = value - def forward( self, sample: torch.Tensor, diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 923b41119624..51c34b7fe965 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -22,7 +22,7 @@ from ...loaders import PeftAdapterMixin from ...models.attention_processor import AttentionProcessor from ...models.modeling_utils import ModelMixin -from ...utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -178,10 +178,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - @classmethod def from_transformer( cls, @@ -330,24 +326,12 @@ def forward( block_samples = () for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def 
create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, - **ckpt_kwargs, ) else: @@ -364,23 +348,11 @@ def custom_forward(*inputs): single_block_samples = () for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, temb, image_rotary_emb, - **ckpt_kwargs, ) else: diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py index 9e361f2b16e5..1b0b4bae6410 100644 --- a/src/diffusers/models/controlnets/controlnet_sd3.py +++ b/src/diffusers/models/controlnets/controlnet_sd3.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import JointTransformerBlock from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed @@ -262,10 +262,6 @@ def unfuse_qkv_projections(self): if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - # Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer # we should have handled this in conversion script def _get_pos_embed_from_transformer(self, transformer): @@ -382,30 +378,16 @@ def forward( for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} if self.context_embedder is not None: - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, - **ckpt_kwargs, ) else: # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states` - hidden_states = 
torch.utils.checkpoint.checkpoint( - create_custom_forward(block), hidden_states, temb, **ckpt_kwargs - ) + hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb) else: if self.context_embedder is not None: diff --git a/src/diffusers/models/controlnets/controlnet_sparsectrl.py b/src/diffusers/models/controlnets/controlnet_sparsectrl.py index 807cbd339ef9..4edc91cacaa7 100644 --- a/src/diffusers/models/controlnets/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnets/controlnet_sparsectrl.py @@ -590,10 +590,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)): - module.gradient_checkpointing = value - def forward( self, sample: torch.Tensor, diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py index 1bf176101c61..076e966f3d37 100644 --- a/src/diffusers/models/controlnets/controlnet_union.py +++ b/src/diffusers/models/controlnets/controlnet_union.py @@ -29,8 +29,6 @@ from ..embeddings import TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin from ..unets.unet_2d_blocks import ( - CrossAttnDownBlock2D, - DownBlock2D, UNetMidBlock2DCrossAttn, get_down_block, ) @@ -599,10 +597,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): - module.gradient_checkpointing = value - def forward( self, sample: torch.Tensor, diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py index 8a8901d82d90..608be6b70277 100644 --- a/src/diffusers/models/controlnets/controlnet_xs.py +++ b/src/diffusers/models/controlnets/controlnet_xs.py @@ -20,7 +20,7 @@ from torch import Tensor, nn from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import BaseOutput, is_torch_version, logging +from ...utils import BaseOutput, logging from ...utils.torch_utils import apply_freeu from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -864,10 +864,6 @@ def freeze_unet_params(self) -> None: for u in self.up_blocks: u.freeze_base_params() - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: @@ -1450,15 +1446,6 @@ def forward( base_blocks = list(zip(self.base_resnets, self.base_attentions)) ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions)) - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip( base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base ): @@ -1468,13 +1455,7 @@ def custom_forward(*inputs): # apply base subblock 
if torch.is_grad_enabled() and self.gradient_checkpointing: - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - h_base = torch.utils.checkpoint.checkpoint( - create_custom_forward(b_res), - h_base, - temb, - **ckpt_kwargs, - ) + h_base = self._gradient_checkpointing_func(b_res, h_base, temb) else: h_base = b_res(h_base, temb) @@ -1491,13 +1472,7 @@ def custom_forward(*inputs): # apply ctrl subblock if apply_control: if torch.is_grad_enabled() and self.gradient_checkpointing: - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - h_ctrl = torch.utils.checkpoint.checkpoint( - create_custom_forward(c_res), - h_ctrl, - temb, - **ckpt_kwargs, - ) + h_ctrl = self._gradient_checkpointing_func(c_res, h_ctrl, temb) else: h_ctrl = c_res(h_ctrl, temb) if c_attn is not None: @@ -1862,15 +1837,6 @@ def forward( and getattr(self, "b2", None) ) - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - def maybe_apply_freeu_to_subblock(hidden_states, res_h_base): # FreeU: Only operate on the first two stages if is_freeu_enabled: @@ -1900,13 +1866,7 @@ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base): hidden_states = torch.cat([hidden_states, res_h_base], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index f1f36b87987d..4938ed23c506 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Dict, Union +from typing import Dict, Union import torch import torch.nn as nn @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin -from ...utils import is_torch_version, logging +from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention_processor import ( Attention, @@ -444,10 +444,6 @@ def unfuse_qkv_projections(self): if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.FloatTensor, @@ -469,23 +465,11 @@ def forward( # MMDiT blocks. 
for index_block, block in enumerate(self.joint_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, - **ckpt_kwargs, ) else: @@ -500,22 +484,10 @@ def custom_forward(*inputs): for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - combined_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + combined_hidden_states = self._gradient_checkpointing_func( + block, combined_hidden_states, temb, - **ckpt_kwargs, ) else: diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index c3039180b81d..797ad30017a6 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 @@ -330,9 +330,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: @@ -488,22 +485,13 @@ def forward( # 3. 
Transformer blocks for i, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, emb, image_rotary_emb, attention_kwargs, - **ckpt_kwargs, ) else: hidden_states, encoder_hidden_states = block( diff --git a/src/diffusers/models/transformers/consisid_transformer_3d.py b/src/diffusers/models/transformers/consisid_transformer_3d.py index 86a6628b5161..f312553e4c05 100644 --- a/src/diffusers/models/transformers/consisid_transformer_3d.py +++ b/src/diffusers/models/transformers/consisid_transformer_3d.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0 @@ -595,9 +595,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - def _init_face_inputs(self): self.local_facial_extractor = LocalFacialExtractor( id_dim=self.LFE_id_dim, @@ -745,22 +742,13 @@ def forward( # 3. Transformer blocks ca_idx = 0 for i, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, emb, image_rotary_emb, - **ckpt_kwargs, ) else: hidden_states, encoder_hidden_states = block( diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py index 7eac313c14db..6e83f49db71c 100644 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -18,7 +18,7 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging +from ...utils import logging from ..attention import BasicTransformerBlock from ..embeddings import PatchEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -144,10 +144,6 @@ def __init__( self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels ) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -186,19 +182,8 @@ def forward( # 2. 
Blocks for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, None, None, @@ -206,7 +191,6 @@ def custom_forward(*inputs): timestep, cross_attention_kwargs, class_labels, - **ckpt_kwargs, ) else: hidden_states = block( diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index be06f44a9efe..cc8a52f28626 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -164,9 +164,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -241,7 +238,7 @@ def forward( zip(self.transformer_blocks, self.temporal_transformer_blocks) ): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self._gradient_checkpointing_func( spatial_block, hidden_states, None, # attention_mask @@ -250,7 +247,6 @@ def forward( timestep_spatial, None, # cross_attention_kwargs None, # class_labels - use_reentrant=False, ) else: hidden_states = spatial_block( @@ -274,7 +270,7 @@ def forward( hidden_states = hidden_states + self.temp_pos_embed if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = self._gradient_checkpointing_func( temp_block, hidden_states, None, # attention_mask @@ -283,7 +279,6 @@ def forward( timestep_temp, None, # cross_attention_kwargs None, # class_labels - use_reentrant=False, ) else: hidden_states = temp_block( diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index b1740cc08fdf..8e290074a018 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -17,7 +17,7 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging +from ...utils import logging from ..attention import BasicTransformerBlock from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection @@ -184,10 +184,6 @@ def __init__( in_features=self.config.caption_channels, hidden_size=self.inner_dim ) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: @@ -388,19 +384,8 @@ def forward( # 2. 
Blocks for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, attention_mask, encoder_hidden_states, @@ -408,7 +393,6 @@ def custom_forward(*inputs): timestep, cross_attention_kwargs, None, - **ckpt_kwargs, ) else: hidden_states = block( diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index a2a54406430d..cface676b409 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -19,7 +19,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention_processor import ( Attention, AttentionProcessor, @@ -308,10 +308,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: @@ -438,21 +434,9 @@ def forward( # 2. Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - for block in self.transformer_blocks: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, attention_mask, encoder_hidden_states, @@ -460,7 +444,6 @@ def custom_forward(*inputs): timestep, post_patch_height, post_patch_width, - **ckpt_kwargs, ) else: diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py index bb370f20f21b..d81b6447adb0 100644 --- a/src/diffusers/models/transformers/stable_audio_transformer.py +++ b/src/diffusers/models/transformers/stable_audio_transformer.py @@ -13,7 +13,7 @@ # limitations under the License. 
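Aside for readers skimming the near-identical hunks above and below: they all converge on one calling convention. Each block loop is guarded by torch.is_grad_enabled() and the module's gradient_checkpointing flag, and dispatches through a _gradient_checkpointing_func attribute instead of building a local create_custom_forward closure around torch.utils.checkpoint.checkpoint. The sketch below shows that convention in isolation; TinyBlock, TinyTransformer, and the hand-rolled enable_gradient_checkpointing are illustrative stand-ins, not diffusers code, and the non-reentrant default simply mirrors the use_reentrant=False kwargs that the removed inline call sites were passing.

# Illustrative sketch only -- TinyBlock / TinyTransformer are not diffusers classes.
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint


class TinyBlock(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        return F.silu(self.proj(hidden_states + temb))


class TinyTransformer(nn.Module):
    def __init__(self, dim: int = 32, num_blocks: int = 2) -> None:
        super().__init__()
        self.blocks = nn.ModuleList([TinyBlock(dim) for _ in range(num_blocks)])
        # The refactored forward passes rely on these two attributes; here they are
        # wired up by hand instead of being managed by the library.
        self.gradient_checkpointing = False
        self._gradient_checkpointing_func = None

    def enable_gradient_checkpointing(self) -> None:
        # Non-reentrant checkpointing, matching the use_reentrant=False kwargs that the
        # removed create_custom_forward call sites were passing.
        self._gradient_checkpointing_func = partial(
            torch.utils.checkpoint.checkpoint, use_reentrant=False
        )
        self.gradient_checkpointing = True

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Same call shape as the rewritten hunks: the block first, then its args.
                hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb)
            else:
                hidden_states = block(hidden_states, temb)
        return hidden_states


if __name__ == "__main__":
    model = TinyTransformer()
    model.enable_gradient_checkpointing()
    x = torch.randn(4, 32, requires_grad=True)
    model(x, torch.randn(4, 32)).sum().backward()  # gradients flow through checkpointed blocks

Routing every block through a single attribute is what lets the per-file create_custom_forward helpers and is_torch_version guards be deleted: the checkpointing policy (reentrant or not, or any custom callable taking the block followed by its positional arguments) lives in one place rather than being re-implemented inside each forward.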
-from typing import Any, Dict, Optional, Union +from typing import Dict, Optional, Union import numpy as np import torch @@ -29,7 +29,7 @@ ) from ...models.modeling_utils import ModelMixin from ...models.transformers.transformer_2d import Transformer2DModelOutput -from ...utils import is_torch_version, logging +from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph @@ -346,10 +346,6 @@ def set_default_attn_processor(self): """ self.set_attn_processor(StableAudioAttnProcessor2_0()) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.FloatTensor, @@ -416,25 +412,13 @@ def forward( for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, attention_mask, cross_attention_hidden_states, encoder_attention_mask, rotary_embedding, - **ckpt_kwargs, ) else: diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index 35e78877f27e..a88ee6c9c9b8 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -18,7 +18,7 @@ from torch import nn from ...configuration_utils import LegacyConfigMixin, register_to_config -from ...utils import deprecate, is_torch_version, logging +from ...utils import deprecate, logging from ..attention import BasicTransformerBlock from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput @@ -321,10 +321,6 @@ def _init_patched_inputs(self, norm_type): in_features=self.caption_channels, hidden_size=self.inner_dim ) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -417,19 +413,8 @@ def forward( # 2. 
Blocks for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, attention_mask, encoder_hidden_states, @@ -437,7 +422,6 @@ def custom_forward(*inputs): timestep, cross_attention_kwargs, class_labels, - **ckpt_kwargs, ) else: hidden_states = block( diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index f32c38394ba4..44116289b5a3 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple +from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging +from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import AllegroAttnProcessor2_0, Attention @@ -303,9 +303,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -375,23 +372,14 @@ def forward( for i, block in enumerate(self.transformer_blocks): # TODO(aryan): Implement gradient checkpointing if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep, attention_mask, encoder_attention_mask, image_rotary_emb, - **ckpt_kwargs, ) else: hidden_states = block( diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index 0376cc2fd70d..da7133791f37 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -13,7 +13,7 @@ # limitations under the License. 
-from typing import Any, Dict, Union +from typing import Dict, Union import torch import torch.nn as nn @@ -27,7 +27,7 @@ ) from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous -from ...utils import is_torch_version, logging +from ...utils import logging from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..normalization import CogView3PlusAdaLayerNormZeroTextImage @@ -289,10 +289,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -344,20 +340,11 @@ def forward( for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, emb, - **ckpt_kwargs, ) else: hidden_states, encoder_hidden_states = block( diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index db8d73856689..d39064bb17b0 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -32,7 +32,7 @@ ) from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed @@ -422,10 +422,6 @@ def unfuse_qkv_projections(self): if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -520,24 +516,12 @@ def forward( for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, 
temb, image_rotary_emb, - **ckpt_kwargs, ) else: @@ -564,23 +548,11 @@ def custom_forward(*inputs): for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, temb, image_rotary_emb, - **ckpt_kwargs, ) else: diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 210a2e711972..9a2cebb8c090 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor from ..embeddings import ( @@ -671,10 +671,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -733,38 +729,24 @@ def forward( # 4. 
Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - for block in self.transformer_blocks: - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb, - **ckpt_kwargs, ) for block in self.single_transformer_blocks: - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb, - **ckpt_kwargs, ) else: diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index d16430f27931..708c91717015 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...loaders.single_file_model import FromOriginalModelMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import MochiAttention, MochiAttnProcessor2_0 @@ -403,10 +403,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -459,22 +455,13 @@ def forward( for i, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, encoder_attention_mask, image_rotary_emb, - **ckpt_kwargs, ) else: hidden_states, encoder_hidden_states = block( diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 2688d3640ea5..e24a28fc3d7b 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -28,7 +28,7 @@ ) from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import 
maybe_allow_in_graph from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -329,10 +329,6 @@ def unfuse_qkv_projections(self): if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.FloatTensor, @@ -404,24 +400,12 @@ def forward( is_skip = True if skip_layers is not None and index_block in skip_layers else False if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, joint_attention_kwargs, - **ckpt_kwargs, ) elif not is_skip: encoder_hidden_states, hidden_states = block( diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py index 3b5aedb79e3c..5580d0f70f9f 100644 --- a/src/diffusers/models/transformers/transformer_temporal.py +++ b/src/diffusers/models/transformers/transformer_temporal.py @@ -343,19 +343,11 @@ def forward( # 2. Blocks for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = torch.utils.checkpoint.checkpoint( - block, - hidden_states, - None, - encoder_hidden_states, - None, - use_reentrant=False, + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, None, encoder_hidden_states, None ) else: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - ) + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states) hidden_states_mix = hidden_states hidden_states_mix = hidden_states_mix + emb diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index 84a1322d2a95..5a7fc32223d6 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -248,10 +248,6 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, sample: torch.Tensor, diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index b4e0cea7c71d..3e7039a97ff5 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from torch import nn -from ...utils import deprecate, is_torch_version, logging +from ...utils import deprecate, logging from ...utils.torch_utils import apply_freeu from ..activations import get_activation from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 @@ 
-737,24 +737,12 @@ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = No hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} if attn is not None: hidden_states = attn(hidden_states, temb=temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), + hidden_states = self._gradient_checkpointing_func( + resnet, hidden_states, temb, - **ckpt_kwargs, ) else: if attn is not None: @@ -883,17 +871,6 @@ def forward( hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -902,12 +879,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = attn( hidden_states, @@ -1156,23 +1128,7 @@ def forward( for resnet, attn in zip(self.resnets, self.attentions): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn(hidden_states, **cross_attention_kwargs) output_states = output_states + (hidden_states,) else: @@ -1304,23 +1260,7 @@ def forward( for i, (resnet, attn) in enumerate(blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1418,21 +1358,7 @@ def forward( for resnet in self.resnets: if torch.is_grad_enabled() and 
self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -1906,21 +1832,7 @@ def forward( for resnet in self.resnets: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -2058,17 +1970,7 @@ def forward( for resnet, attn in zip(self.resnets, self.attentions): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2153,21 +2055,7 @@ def forward( for resnet in self.resnets: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -2262,22 +2150,10 @@ def forward( for resnet, attn in zip(self.resnets, self.attentions): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), + hidden_states = self._gradient_checkpointing_func( + resnet, hidden_states, temb, - **ckpt_kwargs, ) hidden_states = attn( hidden_states, @@ -2423,23 +2299,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return 
module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn(hidden_states) else: hidden_states = resnet(hidden_states, temb) @@ -2588,23 +2448,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2721,21 +2565,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -3251,21 +3081,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -3409,17 +3225,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -3512,21 +3318,7 @@ def forward( for resnet in self.resnets: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def 
custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -3640,22 +3432,10 @@ def forward( for resnet, attn in zip(self.resnets, self.attentions): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), + hidden_states = self._gradient_checkpointing_func( + resnet, hidden_states, temb, - **ckpt_kwargs, ) hidden_states = attn( hidden_states, diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 3447fa0674bc..5674d8ba26ec 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -834,10 +834,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. 
diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 195f7601dd54..8d7614a23383 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -17,7 +17,7 @@ import torch from torch import nn -from ...utils import deprecate, is_torch_version, logging +from ...utils import deprecate, logging from ...utils.torch_utils import apply_freeu from ..attention import Attention from ..resnet import ( @@ -1078,31 +1078,14 @@ def forward( ) for attn, resnet in zip(self.attentions, self.resnets[1:]): - if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, image_only_indicator=image_only_indicator, return_dict=False, )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - image_only_indicator, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator) else: hidden_states = attn( hidden_states, @@ -1110,11 +1093,7 @@ def custom_forward(*inputs): image_only_indicator=image_only_indicator, return_dict=False, )[0] - hidden_states = resnet( - hidden_states, - temb, - image_only_indicator=image_only_indicator, - ) + hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator) return hidden_states @@ -1169,34 +1148,9 @@ def forward( output_states = () for resnet in self.resnets: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - image_only_indicator, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - image_only_indicator, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator) else: - hidden_states = resnet( - hidden_states, - temb, - image_only_indicator=image_only_indicator, - ) + hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator) output_states = output_states + (hidden_states,) @@ -1281,25 +1235,8 @@ def forward( blocks = list(zip(self.resnets, self.attentions)) for resnet, attn in blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - image_only_indicator, - **ckpt_kwargs, - ) + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = 
self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator) hidden_states = attn( hidden_states, @@ -1308,11 +1245,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet( - hidden_states, - temb, - image_only_indicator=image_only_indicator, - ) + hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1385,34 +1318,9 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - image_only_indicator, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - image_only_indicator, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator) else: - hidden_states = resnet( - hidden_states, - temb, - image_only_indicator=image_only_indicator, - ) + hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1495,25 +1403,8 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - image_only_indicator, - **ckpt_kwargs, - ) + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1521,11 +1412,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet( - hidden_states, - temb, - image_only_indicator=image_only_indicator, - ) + hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index 398609778e65..845d93b9db09 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -37,11 +37,7 @@ from ..modeling_utils import ModelMixin from ..transformers.transformer_temporal import TransformerTemporalModel from .unet_3d_blocks import ( - CrossAttnDownBlock3D, - CrossAttnUpBlock3D, - DownBlock3D, UNetMidBlock3DCrossAttn, - UpBlock3D, get_down_block, get_up_block, ) @@ -472,10 +468,6 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): - 
module.gradient_checkpointing = value - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu def enable_freeu(self, s1, s2, b1, b2): r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py index d5d98c256357..f0eca75de169 100644 --- a/src/diffusers/models/unets/unet_i2vgen_xl.py +++ b/src/diffusers/models/unets/unet_i2vgen_xl.py @@ -35,11 +35,7 @@ from ..modeling_utils import ModelMixin from ..transformers.transformer_temporal import TransformerTemporalModel from .unet_3d_blocks import ( - CrossAttnDownBlock3D, - CrossAttnUpBlock3D, - DownBlock3D, UNetMidBlock3DCrossAttn, - UpBlock3D, get_down_block, get_up_block, ) @@ -436,11 +432,6 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) - # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel._set_gradient_checkpointing - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): - module.gradient_checkpointing = value - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu def enable_freeu(self, s1, s2, b1, b2): r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. diff --git a/src/diffusers/models/unets/unet_kandinsky3.py b/src/diffusers/models/unets/unet_kandinsky3.py index f611e7d82b1d..73bf0020b481 100644 --- a/src/diffusers/models/unets/unet_kandinsky3.py +++ b/src/diffusers/models/unets/unet_kandinsky3.py @@ -205,10 +205,6 @@ def set_default_attn_processor(self): """ self.set_attn_processor(AttnProcessor()) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True): if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 1d0a38a8fb13..49be9e87520f 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin -from ...utils import BaseOutput, deprecate, is_torch_version, logging +from ...utils import BaseOutput, deprecate, logging from ...utils.torch_utils import apply_freeu from ..attention import BasicTransformerBlock from ..attention_processor import ( @@ -324,25 +324,7 @@ def forward( blocks = zip(self.resnets, self.motion_modules) for resnet, motion_module in blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = 
resnet(input_tensor=hidden_states, temb=temb) @@ -514,23 +496,7 @@ def forward( blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) for i, (resnet, attn, motion_module) in enumerate(blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(input_tensor=hidden_states, temb=temb) @@ -543,10 +509,7 @@ def custom_forward(*inputs): return_dict=False, )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - ) + hidden_states = motion_module(hidden_states, num_frames=num_frames) # apply additional residuals to the output of the last pair of resnet and attention blocks if i == len(blocks) - 1 and additional_residuals is not None: @@ -733,23 +696,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(input_tensor=hidden_states, temb=temb) @@ -762,10 +709,7 @@ def custom_forward(*inputs): return_dict=False, )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - ) + hidden_states = motion_module(hidden_states, num_frames=num_frames) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -896,24 +840,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(input_tensor=hidden_states, temb=temb) @@ -1080,34 +1007,10 @@ def forward( )[0] if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - 
create_custom_forward(motion_module), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(motion_module, hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - ) + hidden_states = motion_module(hidden_states, num_frames=num_frames) hidden_states = resnet(input_tensor=hidden_states, temb=temb) return hidden_states @@ -1966,10 +1869,6 @@ def set_default_attn_processor(self) -> None: self.set_attn_processor(processor) - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)): - module.gradient_checkpointing = value - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None: r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py index 172c1e6bbb05..db4ace9656a3 100644 --- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py +++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -320,10 +320,6 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: """ diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py index 238e6b411356..f57754435fdc 100644 --- a/src/diffusers/models/unets/unet_stable_cascade.py +++ b/src/diffusers/models/unets/unet_stable_cascade.py @@ -387,9 +387,6 @@ def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=Tru self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, value=False): - self.gradient_checkpointing = value - def _init_weights(self, m): if isinstance(m, (nn.Conv2d, nn.Linear)): torch.nn.init.xavier_uniform_(m.weight) @@ -456,29 +453,18 @@ def _down_encode(self, x, r_embed, clip): block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - for down_block, downscaler, repmap in block_group: x = downscaler(x) for i in range(len(repmap) + 1): for block in down_block: if isinstance(block, SDCascadeResBlock): - x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False) + x = self._gradient_checkpointing_func(block, x) elif isinstance(block, SDCascadeAttnBlock): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, clip, use_reentrant=False - ) + x = self._gradient_checkpointing_func(block, x, clip) elif isinstance(block, SDCascadeTimestepBlock): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, r_embed, use_reentrant=False 
- ) + x = self._gradient_checkpointing_func(block, x, r_embed) else: - x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), use_reentrant=False) + x = self._gradient_checkpointing_func(block) if i < len(repmap): x = repmap[i](x) level_outputs.insert(0, x) @@ -505,13 +491,6 @@ def _up_decode(self, level_outputs, r_embed, clip): block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - for i, (up_block, upscaler, repmap) in enumerate(block_group): for j in range(len(repmap) + 1): for k, block in enumerate(up_block): @@ -523,19 +502,13 @@ def custom_forward(*inputs): x.float(), skip.shape[-2:], mode="bilinear", align_corners=True ) x = x.to(orig_type) - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, skip, use_reentrant=False - ) + x = self._gradient_checkpointing_func(block, x, skip) elif isinstance(block, SDCascadeAttnBlock): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, clip, use_reentrant=False - ) + x = self._gradient_checkpointing_func(block, x, clip) elif isinstance(block, SDCascadeTimestepBlock): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, r_embed, use_reentrant=False - ) + x = self._gradient_checkpointing_func(block, x, r_embed) else: - x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False) + x = self._gradient_checkpointing_func(block, x) if j < len(repmap): x = repmap[j](x) x = upscaler(x) diff --git a/src/diffusers/models/unets/uvit_2d.py b/src/diffusers/models/unets/uvit_2d.py index 785f0f30aaae..94b39c84f055 100644 --- a/src/diffusers/models/unets/uvit_2d.py +++ b/src/diffusers/models/unets/uvit_2d.py @@ -148,9 +148,6 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - pass - def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None): encoder_hidden_states = self.encoder_proj(encoder_hidden_states) encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states) diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py index a33e26568772..00bed864ba34 100644 --- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py @@ -38,7 +38,7 @@ from ...models.transformers.transformer_2d import Transformer2DModel from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D from ...models.unets.unet_2d_condition import UNet2DConditionOutput -from ...utils import BaseOutput, is_torch_version, logging +from ...utils import BaseOutput, logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -673,11 +673,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, sample: torch.Tensor, @@ -1114,23 +1109,7 @@ def forward( for i in range(num_layers): if 
torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.resnets[i]), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb) for idx, cross_attention_dim in enumerate(self.cross_attention_dim): if cross_attention_dim is not None and idx <= 1: forward_encoder_hidden_states = encoder_hidden_states @@ -1141,8 +1120,8 @@ def custom_forward(*inputs): else: forward_encoder_hidden_states = None forward_encoder_attention_mask = None - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False), + hidden_states = self._gradient_checkpointing_func( + self.attentions[i * num_attention_per_layer + idx], hidden_states, forward_encoder_hidden_states, None, # timestep @@ -1150,7 +1129,6 @@ def custom_forward(*inputs): cross_attention_kwargs, attention_mask, forward_encoder_attention_mask, - **ckpt_kwargs, )[0] else: hidden_states = self.resnets[i](hidden_states, temb) @@ -1292,17 +1270,6 @@ def forward( for i in range(len(self.resnets[1:])): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} for idx, cross_attention_dim in enumerate(self.cross_attention_dim): if cross_attention_dim is not None and idx <= 1: forward_encoder_hidden_states = encoder_hidden_states @@ -1313,8 +1280,8 @@ def custom_forward(*inputs): else: forward_encoder_hidden_states = None forward_encoder_attention_mask = None - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False), + hidden_states = self._gradient_checkpointing_func( + self.attentions[i * num_attention_per_layer + idx], hidden_states, forward_encoder_hidden_states, None, # timestep @@ -1322,14 +1289,8 @@ def custom_forward(*inputs): cross_attention_kwargs, attention_mask, forward_encoder_attention_mask, - **ckpt_kwargs, )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.resnets[i + 1]), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(self.resnets[i + 1], hidden_states, temb) else: for idx, cross_attention_dim in enumerate(self.cross_attention_dim): if cross_attention_dim is not None and idx <= 1: @@ -1466,23 +1427,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = 
torch.utils.checkpoint.checkpoint( - create_custom_forward(self.resnets[i]), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb) for idx, cross_attention_dim in enumerate(self.cross_attention_dim): if cross_attention_dim is not None and idx <= 1: forward_encoder_hidden_states = encoder_hidden_states @@ -1493,8 +1438,8 @@ def custom_forward(*inputs): else: forward_encoder_hidden_states = None forward_encoder_attention_mask = None - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False), + hidden_states = self._gradient_checkpointing_func( + self.attentions[i * num_attention_per_layer + idx], hidden_states, forward_encoder_hidden_states, None, # timestep @@ -1502,7 +1447,6 @@ def custom_forward(*inputs): cross_attention_kwargs, attention_mask, forward_encoder_attention_mask, - **ckpt_kwargs, )[0] else: hidden_states = self.resnets[i](hidden_states, temb) diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py index 0d78b987ce77..d2408417f590 100644 --- a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py +++ b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py @@ -174,19 +174,16 @@ def forward( ) use_cache = False - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions, query_length) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self._gradient_checkpointing_func( + layer_module, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, + query_length, ) else: layer_outputs = layer_module( diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 4d9e50e3a2b4..bc276811ff4a 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -34,7 +34,7 @@ from ....models.transformers.dual_transformer_2d import DualTransformer2DModel from ....models.transformers.transformer_2d import Transformer2DModel from ....models.unets.unet_2d_condition import UNet2DConditionOutput -from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ....utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ....utils.torch_utils import apply_freeu @@ -963,10 +963,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def enable_freeu(self, s1, s2, b1, b2): r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. 
@@ -1597,21 +1593,7 @@ def forward( for resnet in self.resnets: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -1734,23 +1716,7 @@ def forward( for i, (resnet, attn) in enumerate(blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1876,21 +1842,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = resnet(hidden_states, temb) @@ -2035,23 +1987,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2230,25 +2166,9 @@ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = No hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} if attn is not 
None: hidden_states = attn(hidden_states, temb=temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: if attn is not None: hidden_states = attn(hidden_states, temb=temb) @@ -2377,17 +2297,6 @@ def forward( hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2396,12 +2305,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = attn( hidden_states, diff --git a/src/diffusers/pipelines/kolors/text_encoder.py b/src/diffusers/pipelines/kolors/text_encoder.py index 5eb8d4c43d02..f07d064cbc22 100644 --- a/src/diffusers/pipelines/kolors/text_encoder.py +++ b/src/diffusers/pipelines/kolors/text_encoder.py @@ -605,7 +605,7 @@ def forward( layer = self._get_layer(index) if torch.is_grad_enabled() and self.gradient_checkpointing: - layer_ret = torch.utils.checkpoint.checkpoint( + layer_ret = self._gradient_checkpointing_func( layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache ) else: @@ -666,10 +666,6 @@ def get_position_ids(self, input_ids, device): position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) return position_ids - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, GLMTransformer): - module.gradient_checkpointing = value - def default_init(cls, *args, **kwargs): return cls(*args, **kwargs) diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index d079e71fe38e..c7aa76a01fb8 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -544,10 +544,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (LDMBertEncoder,)): - module.gradient_checkpointing = value - @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -688,15 +684,8 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self._gradient_checkpointing_func( + encoder_layer, hidden_states, attention_mask, (head_mask[idx] if head_mask 
is not None else None), diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index f90fc82a98ad..9863c506d743 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -29,7 +29,6 @@ AttnProcessor, ) from ...models.modeling_utils import ModelMixin -from ...utils import is_torch_version from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm @@ -138,9 +137,6 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - def gen_r_embedding(self, r, max_positions=10000): r = r * max_positions half_dim = self.c_r // 2 @@ -159,33 +155,13 @@ def forward(self, x, r, c): r_embed = self.gen_r_embedding(r) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - for block in self.blocks: - if isinstance(block, AttnBlock): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, c_embed, use_reentrant=False - ) - elif isinstance(block, TimestepBlock): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, r_embed, use_reentrant=False - ) - else: - x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False) - else: - for block in self.blocks: - if isinstance(block, AttnBlock): - x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed) - elif isinstance(block, TimestepBlock): - x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed) - else: - x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x) + for block in self.blocks: + if isinstance(block, AttnBlock): + x = self._gradient_checkpointing_func(block, x, c_embed) + elif isinstance(block, TimestepBlock): + x = self._gradient_checkpointing_func(block, x, r_embed) + else: + x = self._gradient_checkpointing_func(block, x) else: for block in self.blocks: if isinstance(block, AttnBlock): From 066465e11273335121ebb1afb4c0e6c8d1bb3c5c Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 25 Jan 2025 17:40:03 +0100 Subject: [PATCH 5/7] =?UTF-8?q?more=20cleanup=20=F0=9F=A7=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/community/matryoshka.py | 79 ++----------------- .../pixart/controlnet_pixart_alpha.py | 20 +---- 2 files changed, 7 insertions(+), 92 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 1d7a367ecc60..4895bd150114 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -80,7 +80,6 @@ USE_PEFT_BACKEND, BaseOutput, deprecate, - is_torch_version, is_torch_xla_available, logging, replace_example_docstring, @@ -869,23 +868,7 @@ def forward( for i, (resnet, attn) in enumerate(blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = 
torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1030,17 +1013,6 @@ def forward( hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1049,12 +1021,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: hidden_states = attn( hidden_states, @@ -1192,23 +1159,7 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1282,10 +1233,6 @@ def __init__( ] ) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -1365,19 +1312,8 @@ def forward( # Blocks for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, attention_mask, encoder_hidden_states, @@ -1385,7 +1321,6 @@ def custom_forward(*inputs): timestep, cross_attention_kwargs, class_labels, - **ckpt_kwargs, ) else: hidden_states = block( @@ -2724,10 +2659,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): 
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. diff --git a/examples/research_projects/pixart/controlnet_pixart_alpha.py b/examples/research_projects/pixart/controlnet_pixart_alpha.py index f825719a1364..8f2eb974398d 100644 --- a/examples/research_projects/pixart/controlnet_pixart_alpha.py +++ b/examples/research_projects/pixart/controlnet_pixart_alpha.py @@ -8,7 +8,6 @@ from diffusers.models.attention import BasicTransformerBlock from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin -from diffusers.utils.torch_utils import is_torch_version class PixArtControlNetAdapterBlock(nn.Module): @@ -151,10 +150,6 @@ def __init__( self.transformer = transformer self.controlnet = controlnet - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, @@ -220,18 +215,8 @@ def forward( print("Gradient checkpointing is not supported for the controlnet transformer model, yet.") exit(1) - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, attention_mask, encoder_hidden_states, @@ -239,7 +224,6 @@ def custom_forward(*inputs): timestep, cross_attention_kwargs, None, - **ckpt_kwargs, ) else: # the control nets are only used for the blocks 1 to self.blocks_num From 50d0a283e27c4b7d79ae4f5b537c927a911db4de Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 25 Jan 2025 17:42:10 +0100 Subject: [PATCH 6/7] make fix-copies --- src/diffusers/models/unets/unet_2d_blocks.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 3e7039a97ff5..e082d524e766 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -739,11 +739,7 @@ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = No if torch.is_grad_enabled() and self.gradient_checkpointing: if attn is not None: hidden_states = attn(hidden_states, temb=temb) - hidden_states = self._gradient_checkpointing_func( - resnet, - hidden_states, - temb, - ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: if attn is not None: hidden_states = attn(hidden_states, temb=temb) From de92b67c988b8d1d825e95cfac53e7a7f86d8866 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 25 Jan 2025 22:27:30 +0100 Subject: [PATCH 7/7] update test --- .../models/unets/unet_motion_model.py | 6 ++++-- tests/models/test_modeling_common.py | 21 ++++++------------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 49be9e87520f..21e4db23a166 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -1007,10 +1007,12 @@ def forward( )[0] if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = 
self._gradient_checkpointing_func(motion_module, hidden_states, temb) + hidden_states = self._gradient_checkpointing_func( + motion_module, hidden_states, None, None, None, num_frames, None + ) hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: - hidden_states = motion_module(hidden_states, num_frames=num_frames) + hidden_states = motion_module(hidden_states, None, None, None, num_frames, None) hidden_states = resnet(input_tensor=hidden_states, temb=temb) return hidden_states diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 05050e05bb19..b88b6f16b9fb 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -953,24 +953,15 @@ def test_gradient_checkpointing_is_applied( init_dict["block_out_channels"] = block_out_channels model_class_copy = copy.copy(self.model_class) - - modules_with_gc_enabled = {} - - # now monkey patch the following function: - # def _set_gradient_checkpointing(self, module, value=False): - # if hasattr(module, "gradient_checkpointing"): - # module.gradient_checkpointing = value - - def _set_gradient_checkpointing_new(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - modules_with_gc_enabled[module.__class__.__name__] = True - - model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new - model = model_class_copy(**init_dict) model.enable_gradient_checkpointing() + modules_with_gc_enabled = {} + for submodule in model.modules(): + if hasattr(submodule, "gradient_checkpointing"): + self.assertTrue(submodule.gradient_checkpointing) + modules_with_gc_enabled[submodule.__class__.__name__] = True + assert set(modules_with_gc_enabled.keys()) == expected_set assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
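
Editor's note (not part of the patch series): the hunks above repeatedly replace per-forward `create_custom_forward` closures and explicit `torch.utils.checkpoint.checkpoint(...)` calls with a single shared `self._gradient_checkpointing_func(block, *args)` call, gated on a `gradient_checkpointing` flag that `enable_gradient_checkpointing()` toggles (as the updated test in tests/models/test_modeling_common.py exercises). The sketch below is a minimal, self-contained illustration of that calling convention on a toy module; the module, its sizes, and the `use_reentrant=False` default used here are assumptions made for the example, not code from the patch.

# Illustrative sketch only. On a real diffusers model you would call
# model.enable_gradient_checkpointing() instead of setting the attributes by hand.
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyBlockStack(nn.Module):
    """Minimal stand-in for a block stack using the refactored convention."""

    def __init__(self, dim: int = 32, num_blocks: int = 4) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks))
        # Flag checked in forward(); the library's enable/disable helpers would flip it.
        self.gradient_checkpointing = False
        # One shared callable replaces the per-forward create_custom_forward closures.
        # use_reentrant=False is assumed here as a reasonable default for recent PyTorch.
        self._gradient_checkpointing_func = lambda module, *args: torch.utils.checkpoint.checkpoint(
            module.__call__, *args, use_reentrant=False
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Same call shape as the refactored blocks above:
                # hidden_states = self._gradient_checkpointing_func(block, hidden_states, ...)
                hidden_states = self._gradient_checkpointing_func(block, hidden_states)
            else:
                hidden_states = block(hidden_states)
        return hidden_states


if __name__ == "__main__":
    model = ToyBlockStack()
    model.gradient_checkpointing = True  # stand-in for enable_gradient_checkpointing()
    out = model(torch.randn(2, 32, requires_grad=True))
    out.sum().backward()  # block activations are recomputed during the backward pass

The design point the series standardizes on is that forward() no longer decides how checkpointing is performed (reentrant vs. non-reentrant, extra kwargs); it only routes the block call through the shared function, which keeps every model's checkpointing path identical and centrally configurable.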