Refactor gradient checkpointing #10611

Merged
merged 10 commits
Jan 28, 2025
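
For orientation: every hunk in this diff makes the same substitution. The per-call-site create_custom_forward closure and explicit torch.utils.checkpoint.checkpoint(...) invocation (with its is_torch_version guard for use_reentrant) are collapsed into a single call to self._gradient_checkpointing_func(...), and the per-model _set_gradient_checkpointing overrides are deleted. The shared helper itself is not part of this excerpt; the sketch below only illustrates the behavior the new call sites appear to rely on, and its names and defaults are assumptions rather than code from the PR.

# Minimal sketch, assuming non-reentrant checkpointing by default
# (mirroring the {"use_reentrant": False} kwargs removed in this diff).
from functools import partial

import torch
import torch.utils.checkpoint


def make_gradient_checkpointing_func(**ckpt_kwargs):
    ckpt_kwargs.setdefault("use_reentrant", False)
    # torch.utils.checkpoint.checkpoint accepts any callable, including an
    # nn.Module, so no custom_forward wrapper is needed anymore.
    return partial(torch.utils.checkpoint.checkpoint, **ckpt_kwargs)


# Old call site (removed throughout this diff):
#     hidden_states = torch.utils.checkpoint.checkpoint(
#         create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs
#     )
# New call site:
#     hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
# where _gradient_checkpointing_func behaves roughly like
# make_gradient_checkpointing_func() bound on the model.
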
79 changes: 5 additions & 74 deletions examples/community/matryoshka.py
@@ -80,7 +80,6 @@
USE_PEFT_BACKEND,
BaseOutput,
deprecate,
is_torch_version,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -869,23 +868,7 @@ def forward(

for i, (resnet, attn) in enumerate(blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1030,17 +1013,6 @@ def forward(
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1049,12 +1021,7 @@ def custom_forward(*inputs):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = attn(
hidden_states,
@@ -1192,23 +1159,7 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1282,10 +1233,6 @@ def __init__(
]
)

def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value

def forward(
self,
hidden_states: torch.Tensor,
@@ -1365,27 +1312,15 @@ def forward(
# Blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
class_labels,
**ckpt_kwargs,
)
else:
hidden_states = block(
@@ -2724,10 +2659,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)

def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value

def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

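The matryoshka example above also drops its local _set_gradient_checkpointing hook, so toggling the flag is presumably handled by the shared base class now. Nothing changes for callers; a minimal usage sketch, assuming a standard diffusers model (the checkpoint name is only illustrative):

from diffusers import UNet2DConditionModel

# Any ModelMixin-based model with checkpointing support is toggled the same way.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet"
)
unet.enable_gradient_checkpointing()  # trades extra compute for lower activation memory
unet.train()  # checkpointing only kicks in when gradients are enabled
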
20 changes: 2 additions & 18 deletions examples/research_projects/pixart/controlnet_pixart_alpha.py
@@ -8,7 +8,6 @@
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.torch_utils import is_torch_version


class PixArtControlNetAdapterBlock(nn.Module):
@@ -151,10 +150,6 @@ def __init__(
self.transformer = transformer
self.controlnet = controlnet

def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value

def forward(
self,
hidden_states: torch.Tensor,
@@ -220,26 +215,15 @@ def forward(
print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
exit(1)

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
None,
**ckpt_kwargs,
)
else:
# the control nets are only used for the blocks 1 to self.blocks_num
4 changes: 0 additions & 4 deletions src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -138,10 +138,6 @@ def __init__(
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, Decoder)):
module.gradient_checkpointing = value

def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
26 changes: 4 additions & 22 deletions src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
@@ -507,19 +507,12 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
sample = sample + residual

if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)

return custom_forward

# Down blocks
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
sample = self._gradient_checkpointing_func(down_block, sample)

# Mid block
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
sample = self._gradient_checkpointing_func(self.mid_block, sample)
else:
# Down blocks
for down_block in self.down_blocks:
@@ -647,19 +640,12 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)

return custom_forward

# Mid block
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
sample = self._gradient_checkpointing_func(self.mid_block, sample)

# Up blocks
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
sample = self._gradient_checkpointing_func(up_block, sample)

else:
# Mid block
@@ -809,10 +795,6 @@ def __init__(
sample_size - self.tile_overlap_w,
)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
module.gradient_checkpointing = value

def enable_tiling(self) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
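The autoencoder diffs above likewise remove isinstance-based _set_gradient_checkpointing overrides (limited to Encoder/Decoder or the Allegro 3D variants). A generic toggle in the spirit of the hasattr variant removed from the matryoshka example would look roughly like the following; this is a sketch of the presumed shared behavior, not code from the PR.

import torch.nn as nn


def set_gradient_checkpointing(root: nn.Module, value: bool = True) -> None:
    # Flip the flag on every submodule that opts in by exposing a
    # `gradient_checkpointing` attribute, instead of hard-coding class names.
    for module in root.modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value
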
67 changes: 14 additions & 53 deletions src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -421,15 +421,8 @@ def forward(
conv_cache_key = f"resnet_{i}"

if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)

return create_forward

hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet,
hidden_states,
temb,
zq,
@@ -523,15 +516,8 @@ def forward(
conv_cache_key = f"resnet_{i}"

if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)

return create_forward

hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet, hidden_states, temb, zq, conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -637,15 +623,8 @@ def forward(
conv_cache_key = f"resnet_{i}"

if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)

return create_forward

hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet,
hidden_states,
temb,
zq,
@@ -774,27 +753,20 @@ def forward(
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)

return custom_forward

# 1. Down
for i, down_block in enumerate(self.down_blocks):
conv_cache_key = f"down_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block),
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
down_block,
hidden_states,
temb,
None,
conv_cache.get(conv_cache_key),
)

# 2. Mid
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
self.mid_block,
hidden_states,
temb,
None,
@@ -940,16 +912,9 @@ def forward(
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)

return custom_forward

# 1. Mid
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
self.mid_block,
hidden_states,
temb,
sample,
@@ -959,8 +924,8 @@ def custom_forward(*inputs):
# 2. Up
for i, up_block in enumerate(self.up_blocks):
conv_cache_key = f"up_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
up_block,
hidden_states,
temb,
sample,
@@ -1122,10 +1087,6 @@ def __init__(
self.tile_overlap_factor_height = 1 / 6
self.tile_overlap_factor_width = 1 / 5

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
module.gradient_checkpointing = value

def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
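One detail from the CogVideoX hunks above is worth calling out: the checkpointed resnets return a (hidden_states, conv_cache) pair and the refactored call unpacks that tuple directly, so the shared helper has to pass multiple return values through unchanged. Plain torch.utils.checkpoint.checkpoint already does this; a self-contained sketch with made-up module names, not taken from the PR:

import torch
import torch.utils.checkpoint


class TwoOutputBlock(torch.nn.Module):
    # Stand-in for a resnet that returns (hidden_states, conv_cache).
    def __init__(self, channels: int):
        super().__init__()
        self.conv = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor):
        y = self.conv(x)
        return y, y.detach()  # second output stands in for a conv cache


block = TwoOutputBlock(8)
x = torch.randn(1, 8, 16, 16, requires_grad=True)
out, cache = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
out.sum().backward()  # activations are recomputed here instead of being stored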