diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 7766442f7133..84db0d061768 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -1104,8 +1104,26 @@ def forward( accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights num_times_accumulated[:, frame_start:frame_end] += weights - hidden_states = torch.where( - num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values + # TODO(aryan): Maybe this could be done in a better way. + # + # Previously, this was: + # hidden_states = torch.where( + # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values + # ) + # + # The reasoning for the change here is that `torch.where` became a bottleneck at some point when trying to reduce memory + # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes + # from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly + # looked into this deeply because other memory optimizations led to more pronounced reductions. + hidden_states = torch.cat( + [ + torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split) + for accumulated_split, num_times_split in zip( + accumulated_values.split(self.context_length, dim=1), + num_times_accumulated.split(self.context_length, dim=1), + ) + ], + dim=1, ).to(dtype) # 3. Feed-forward diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 89cdb76741f7..6125feba5899 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -187,12 +187,12 @@ def forward( hidden_states = self.norm(hidden_states) hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) - hidden_states = self.proj_in(hidden_states) + hidden_states = self.proj_in(input=hidden_states) # 2. Blocks for block in self.transformer_blocks: hidden_states = block( - hidden_states, + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, @@ -200,7 +200,7 @@ def forward( ) # 3.
Output - hidden_states = self.proj_out(hidden_states) + hidden_states = self.proj_out(input=hidden_states) hidden_states = ( hidden_states[None, None, :] .reshape(batch_size, height, width, num_frames, channel) @@ -344,7 +344,7 @@ def custom_forward(*inputs): ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) @@ -352,7 +352,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states=hidden_states) output_states = output_states + (hidden_states,) @@ -531,25 +531,18 @@ def custom_forward(*inputs): temb, **ckpt_kwargs, ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, @@ -563,7 +556,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states=hidden_states) output_states = output_states + (hidden_states,) @@ -757,25 +750,18 @@ def custom_forward(*inputs): temb, **ckpt_kwargs, ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, @@ -783,7 +769,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) return hidden_states @@ -929,13 +915,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) hidden_states = 
motion_module(hidden_states, num_frames=num_frames) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) return hidden_states @@ -1080,10 +1066,19 @@ def forward( if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb) blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) for attn, resnet, motion_module in blocks: + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -1096,14 +1091,6 @@ def custom_forward(*inputs): return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(motion_module), hidden_states, @@ -1117,19 +1104,11 @@ def custom_forward(*inputs): **ckpt_kwargs, ) else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, ) - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) return hidden_states diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index f2763f1c33cc..dc0071a494e3 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Optional, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch +import torch.nn as nn from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock +from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D +from ..models.transformers.transformer_2d import Transformer2DModel from ..models.unets.unet_motion_model import ( + AnimateDiffTransformer3D, CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion, @@ -30,6 +34,114 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class SplitInferenceModule(nn.Module): + r""" + A wrapper module class that splits inputs along a specified dimension before performing a forward pass. + + This module is useful when you need to perform inference on large tensors in a memory-efficient way by breaking + them into smaller chunks, processing each chunk separately, and then reassembling the results. 
+ + Args: + module (`nn.Module`): + The underlying PyTorch module that will be applied to each chunk of split inputs. + split_size (`int`, defaults to `1`): + The size of each chunk after splitting the input tensor. + split_dim (`int`, defaults to `0`): + The dimension along which the input tensors are split. + input_kwargs_to_split (`List[str]`, defaults to `["hidden_states"]`): + A list of keyword arguments (strings) that represent the input tensors to be split. + + Workflow: + 1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using + `torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`. + 2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments + that were passed. + 3. The output tensors from each split are concatenated back together along `split_dim` before returning. + + Example: + ```python + >>> import torch + >>> import torch.nn as nn + + >>> model = nn.Linear(1000, 1000) + >>> split_module = SplitInferenceModule(model, split_size=2, split_dim=0, input_kwargs_to_split=["input"]) + + >>> input_tensor = torch.randn(42, 1000) + >>> # Will split the tensor into 21 slices of shape [2, 1000]. + >>> output = split_module(input=input_tensor) + ``` + + It is also possible to nest `SplitInferenceModule` across different split dimensions for more complex + multi-dimensional splitting. + """ + + def __init__( + self, + module: nn.Module, + split_size: int = 1, + split_dim: int = 0, + input_kwargs_to_split: List[str] = ["hidden_states"], + ) -> None: + super().__init__() + + self.module = module + self.split_size = split_size + self.split_dim = split_dim + self.input_kwargs_to_split = set(input_kwargs_to_split) + + def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + r"""Forward method for the `SplitInferenceModule`. + + This method processes the input by splitting specified keyword arguments along a given dimension, running the + underlying module on each split, and then concatenating the results. The splitting is controlled by the + `split_size` and `split_dim` parameters specified during initialization. + + Args: + *args (`Any`): + Positional arguments that are passed directly to the `module` without modification. + **kwargs (`Dict[str, torch.Tensor]`): + Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the + entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword + arguments are passed unchanged. + + Returns: + `Union[torch.Tensor, Tuple[torch.Tensor]]`: + The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred + without it. + - If the underlying module returns a single tensor, the result will be a single concatenated tensor + along the same `split_dim` after processing all splits. + - If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated + along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors. + """ + split_inputs = {} + + # 1. Split inputs that were specified during initialization and also present in passed kwargs + for key in list(kwargs.keys()): + if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]): + continue + split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim) + kwargs.pop(key) + + # 2. 
Invoke forward pass across each split + results = [] + for split_input in zip(*split_inputs.values()): + inputs = dict(zip(split_inputs.keys(), split_input)) + inputs.update(kwargs) + + intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs) + results.append(intermediate_tensor_or_tensor_tuple) + + # 3. Concatenate split results to obtain final outputs + if isinstance(results[0], torch.Tensor): + return torch.cat(results, dim=self.split_dim) + elif isinstance(results[0], tuple): + return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)]) + else: + raise ValueError( + "In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensors." + ) + + class AnimateDiffFreeNoiseMixin: r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169).""" @@ -70,6 +182,9 @@ def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Dow motion_module.transformer_blocks[i].load_state_dict( basic_transfomer_block.state_dict(), strict=True ) + motion_module.transformer_blocks[i].set_chunk_feed_forward( + basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim + ) def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]): r"""Helper function to disable FreeNoise in transformer blocks.""" @@ -98,6 +213,9 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do motion_module.transformer_blocks[i].load_state_dict( free_noise_transfomer_block.state_dict(), strict=True ) + motion_module.transformer_blocks[i].set_chunk_feed_forward( + free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim + ) def _check_inputs_free_noise( self, @@ -410,6 +528,69 @@ def disable_free_noise(self) -> None: for block in blocks: self._disable_free_noise_in_block(block) + def _enable_split_inference_motion_modules_( + self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int + ) -> None: + for motion_module in motion_modules: + motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"]) + + for i in range(len(motion_module.transformer_blocks)): + motion_module.transformer_blocks[i] = SplitInferenceModule( + motion_module.transformer_blocks[i], + spatial_split_size, + 0, + ["hidden_states", "encoder_hidden_states"], + ) + + motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"]) + + def _enable_split_inference_attentions_( + self, attentions: List[Transformer2DModel], temporal_split_size: int + ) -> None: + for i in range(len(attentions)): + attentions[i] = SplitInferenceModule( + attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"] + ) + + def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None: + for i in range(len(resnets)): + resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"]) + + def _enable_split_inference_samplers_( + self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int + ) -> None: + for i in range(len(samplers)): + samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"]) + + def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None: + r""" + Enable FreeNoise memory optimizations by utilizing +
[`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks. + + Args: + spatial_split_size (`int`, defaults to `256`): + The split size across spatial dimensions for internal blocks. This is used in facilitating split + inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion + modeling blocks. + temporal_split_size (`int`, defaults to `16`): + The split size across temporal dimensions for internal blocks. This is used in facilitating split + inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in spatial + attention, resnets, downsampling and upsampling blocks. + """ + # TODO(aryan): Discuss the best way to provide more control to users + blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] + for block in blocks: + if getattr(block, "motion_modules", None) is not None: + self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size) + if getattr(block, "attentions", None) is not None: + self._enable_split_inference_attentions_(block.attentions, temporal_split_size) + if getattr(block, "resnets", None) is not None: + self._enable_split_inference_resnets_(block.resnets, temporal_split_size) + if getattr(block, "downsamplers", None) is not None: + self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size) + if getattr(block, "upsamplers", None) is not None: + self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size) + @property def free_noise_enabled(self): return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 677267305373..54c83d6a1b68 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -460,6 +460,30 @@ def test_free_noise(self): "Disabling of FreeNoise should lead to results similar to the default pipeline results", ) + def test_free_noise_split_inference(self): + components = self.get_dummy_components() + pipe: AnimateDiffPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise(8, 4) + + inputs_normal = self.get_dummy_inputs(torch_device) + frames_normal = pipe(**inputs_normal).frames[0] + + # Test FreeNoise with split inference memory-optimization + pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4) + + inputs_enable_split_inference = self.get_dummy_inputs(torch_device) + frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0] + + sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum() + self.assertLess( + sum_split_inference, + 1e-4, + "Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results", + ) + def test_free_noise_multi_prompt(self): components = self.get_dummy_components() pipe: AnimateDiffPipeline = self.pipeline_class(**components) diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index 59146115b90a..c3fd4c73736a 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -492,6 +492,34 @@ def test_free_noise(self):
"Disabling of FreeNoise should lead to results similar to the default pipeline results", ) + def test_free_noise_split_inference(self): + components = self.get_dummy_components() + pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise(8, 4) + + inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_normal["num_inference_steps"] = 2 + inputs_normal["strength"] = 0.5 + frames_normal = pipe(**inputs_normal).frames[0] + + # Test FreeNoise with split inference memory-optimization + pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4) + + inputs_enable_split_inference = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_enable_split_inference["num_inference_steps"] = 2 + inputs_enable_split_inference["strength"] = 0.5 + frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0] + + sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum() + self.assertLess( + sum_split_inference, + 1e-4, + "Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results", + ) + def test_free_noise_multi_prompt(self): components = self.get_dummy_components() pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components)