From bb3721a1eff941e3fabebd90f5b31da774050713 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 9 Dec 2022 16:58:19 +0000 Subject: [PATCH 01/19] first proposal --- src/diffusers/models/attention.py | 202 +++++------------- .../models/cross_attention_processors.py | 148 +++++++++++++ src/diffusers/models/unet_2d_blocks.py | 15 +- src/diffusers/models/unet_2d_condition.py | 9 +- 4 files changed, 216 insertions(+), 158 deletions(-) create mode 100644 src/diffusers/models/cross_attention_processors.py diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 8b855a5ed5f4..7144a0fde59d 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -import warnings from dataclasses import dataclass from typing import Optional @@ -25,6 +24,7 @@ from ..models.embeddings import ImagePositionalEmbeddings from ..utils import BaseOutput from ..utils.import_utils import is_xformers_available +from .cross_attention_processors import CrossAttentionProcMixin, CrossAttentionProc, XFormersCrossAttentionProc, SlicedAttentionProc @dataclass @@ -176,7 +176,7 @@ def __init__( self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, cross_attention_inputs=None, return_dict: bool = True): """ Args: hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. @@ -214,7 +214,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, retu # 2. Blocks for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep) + hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep, cross_attention_inputs=cross_attention_inputs) # 3. Output if self.is_input_continuous: @@ -448,49 +448,23 @@ def __init__( self.norm3 = nn.LayerNorm(dim) # if xformers is installed try to use memory_efficient_attention by default - if is_xformers_available(): - try: - self.set_use_memory_efficient_attention_xformers(True) - except Exception as e: - warnings.warn( - "Could not enable memory efficient attention. Make sure xformers is installed" - f" correctly and a GPU is available: {e}" - ) - - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - if not is_xformers_available(): - print("Here is how to install it") - raise ModuleNotFoundError( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers", - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only" - " available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - - def forward(self, hidden_states, context=None, timestep=None): +# if is_xformers_available(): +# try: +# self.set_use_memory_efficient_attention_xformers(True) +# except Exception as e: +# warnings.warn( +# "Could not enable memory efficient attention. Make sure xformers is installed" +# f" correctly and a GPU is available: {e}" +# ) + + def forward(self, hidden_states, context=None, timestep=None, cross_attention_inputs=None): # 1. Self-Attention norm_hidden_states = ( self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) ) if self.only_cross_attention: - hidden_states = self.attn1(norm_hidden_states, context) + hidden_states + hidden_states = self.attn1(norm_hidden_states, context=context, cross_attention_inputs=cross_attention_inputs) + hidden_states else: hidden_states = self.attn1(norm_hidden_states) + hidden_states @@ -499,7 +473,7 @@ def forward(self, hidden_states, context=None, timestep=None): norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states + hidden_states = self.attn2(norm_hidden_states, context=context, cross_attention_inputs=cross_attention_inputs) + hidden_states # 3. Feed-forward hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states @@ -544,7 +518,6 @@ def __init__( # You can set slice_size with `set_attention_slice` self.sliceable_head_dim = heads self._slice_size = None - self._use_memory_efficient_attention_xformers = False self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) @@ -554,127 +527,56 @@ def __init__( self.to_out.append(nn.Linear(inner_dim, query_dim)) self.to_out.append(nn.Dropout(dropout)) - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor + self.attn_proc = CrossAttentionProc(self.heads, self.upcast_attention) - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn_fn = XFormersCrossAttentionProc(self.heads, self.upcast_attention) def set_attention_slice(self, slice_size): if slice_size is not None and slice_size > self.sliceable_head_dim: raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") self._slice_size = slice_size - - def forward(self, hidden_states, context=None, mask=None): - batch_size, sequence_length, _ = hidden_states.shape - - query = self.to_q(hidden_states) - context = context if context is not None else hidden_states - key = self.to_k(context) - value = self.to_v(context) - - dim = query.shape[-1] - - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - # TODO(PVP) - mask is currently never used. Remember to re-implement when used - - # attention, what we cannot get enough of - if self._use_memory_efficient_attention_xformers: - hidden_states = self._memory_efficient_attention_xformers(query, key, value) - # Some versions of xformers return output in fp32, cast it back to the dtype of the input - hidden_states = hidden_states.to(query.dtype) - else: - if self._slice_size is None or query.shape[0] // self._slice_size == 1: - hidden_states = self._attention(query, key, value) - else: - hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) - + self.attn_fn = SlicedAttentionProc(self.heads, self.upcast_attention) + + def set_attn_proc(self, attn_proc: CrossAttentionProcMixin): + if not isinstance(attn_proc, CrossAttentionProcMixin): + subclass = attn_proc.__bases__ if hasattr(attn_proc, "__bases__") else None + raise ValueError(f"`attn_proc` should be a subclass of {CrossAttentionProc}, but is of type {type(attn_proc)} and a subclass of {subclass}.") + self.attn_proc = attn_proc + + def forward(self, hidden_states, context=None, cross_attention_inputs=None): + # attn + cross_attention_inputs = cross_attention_inputs if cross_attention_inputs is not None else {} + hidden_states = self.attn(hidden_states, self.to_q, self.to_k, self.to_v, context=context, **cross_attention_inputs) # linear proj hidden_states = self.to_out[0](hidden_states) # dropout hidden_states = self.to_out[1](hidden_states) return hidden_states - def _attention(self, query, key, value): - if self.upcast_attention: - query = query.float() - key = key.float() - - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - - # cast back to the original dtype - attention_probs = attention_probs.to(value.dtype) - - # compute attention output - hidden_states = torch.bmm(attention_probs, value) - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - - def _sliced_attention(self, query, key, value, sequence_length, dim): - batch_size_attention = query.shape[0] - hidden_states = torch.zeros( - (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype - ) - 
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] - for i in range(hidden_states.shape[0] // slice_size): - start_idx = i * slice_size - end_idx = (i + 1) * slice_size - - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - - if self.upcast_attention: - query_slice = query_slice.float() - key_slice = key_slice.float() - - attn_slice = torch.baddbmm( - torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), - query_slice, - key_slice.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attn_slice = attn_slice.softmax(dim=-1) - - # cast back to the original dtype - attn_slice = attn_slice.to(value.dtype) - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - - def _memory_efficient_attention_xformers(self, query, key, value): - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - class FeedForward(nn.Module): r""" diff --git a/src/diffusers/models/cross_attention_processors.py b/src/diffusers/models/cross_attention_processors.py new file mode 100644 index 000000000000..bab6a8e9dc2e --- /dev/null +++ b/src/diffusers/models/cross_attention_processors.py @@ -0,0 +1,148 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch + +from ..utils.import_utils import is_xformers_available + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class CrossAttentionProcMixin: + + def __init__(self, head_size, upcast_attention): + self.head_size = head_size + self.upcast_attention = upcast_attention + + def __call__(self, hidden_states, query_proj, key_proj, value_proj, context=None): + raise NotImplementedError("Make sure this method is overwritten in the subclass.") + + def batch_to_head_dim(self, tensor, head_size): + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // self.head_size, self.head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // self.head_size, seq_len, dim * self.head_size) + return tensor + + def head_to_batch_dim(self, tensor, head_size): + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, self.head_size, dim // self.head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * self.head_size, seq_len, dim // self.head_size) + return tensor + + def get_attention_scores(self, query, key): + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(dtype) + + return attention_probs + + +class CrossAttentionProc(CrossAttentionProcMixin): + + def __call__(self, hidden_states, query_proj, key_proj, value_proj, context=None): + batch_size, sequence_length, _ = hidden_states.shape + query = query_proj(hidden_states) + + context = context if context is not None else hidden_states + key = key_proj(context) + value = self.value_proj(context) + + query = self.head_to_batch_dim(query, self.head_size) + key = self.head_to_batch_dim(key, self.head_size) + value = self.head_to_batch_dim(value, self.head_size) + + attention_probs = self.get_attention_scores(query, key) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = self.batch_to_head_dim(hidden_states) + + return hidden_states + + +class XFormersCrossAttentionProc(CrossAttentionProcMixin): + + def __call__(self, hidden_states, query_proj, key_proj, value_proj, context=None): + batch_size, sequence_length, _ = hidden_states.shape + query = query_proj(hidden_states) + + context = context if context is not None else hidden_states + key = key_proj(context) + value = self.value_proj(context) + + query = self.head_to_batch_dim(query, self.head_size).contiguous() + key = self.head_to_batch_dim(key, self.head_size).contiguous() + value = self.head_to_batch_dim(value, self.head_size).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + hidden_states = self.batch_to_head_dim(hidden_states) + + return hidden_states + + +class SlicedAttentionProc(CrossAttentionProcMixin): + + def __init__(self, head_size, upcast_attention, slice_size): + super().__init__(head_size=head_size, upcast_attention=upcast_attention) + + self.slice_size = self.slice_size + + def __call__(self, hidden_states, query_proj, key_proj, value_proj, context=None): + batch_size, sequence_length, _ = hidden_states.shape + query = query_proj(hidden_states) + + dim = query.shape[-1] + + context = context if context is not None else hidden_states + 
key = key_proj(context) + value = self.value_proj(context) + + query = self.head_to_batch_dim(query, self.head_size) + key = self.head_to_batch_dim(key, self.head_size) + value = self.head_to_batch_dim(value, self.head_size) + + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + + for i in range(hidden_states.shape[0] // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + attn_slice = self.get_attention_scores(query_slice, key_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = self.batch_to_head_dim(hidden_states) + + return hidden_states diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index aa8d4c9849e1..465fa6432558 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -408,10 +408,10 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, cross_attention_inputs=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states).sample + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_inputs=cross_attention_inputs).sample hidden_states = resnet(hidden_states, temb) return hidden_states @@ -588,7 +588,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, cross_attention_inputs=None): output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -605,11 +605,11 @@ def custom_forward(*inputs): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, cross_attention_inputs )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_inputs=cross_attention_inputs).sample output_states += (hidden_states,) @@ -1175,6 +1175,7 @@ def forward( res_hidden_states_tuple, temb=None, encoder_hidden_states=None, + cross_attention_inputs=None, upsample_size=None, ): for resnet, attn in zip(self.resnets, self.attentions): @@ -1196,11 +1197,11 @@ def custom_forward(*inputs): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, cross_attention_inputs )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, 
encoder_hidden_states=encoder_hidden_states).sample + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_inputs=cross_attention_inputs).sample if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 0cfb15224982..fbb968d5def6 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -232,6 +232,10 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + def set_cross_attention_class(self, cross_attention_cls): + # set recursively + + def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. @@ -307,6 +311,7 @@ def forward( timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, + cross_attention_inputs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -382,6 +387,7 @@ def forward( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, + cross_attention_inputs=cross_attention_inputs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) @@ -389,7 +395,7 @@ def forward( down_block_res_samples += res_samples # 4. mid - sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states, cross_attention_inputs=cross_attention_inputs) # 5. up for i, upsample_block in enumerate(self.up_blocks): @@ -409,6 +415,7 @@ def forward( temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, + cross_attention_inputs=cross_attention_inputs, upsample_size=upsample_size, ) else: From 2cf2902bd7610a78e74ce06c2ea56d7ef87dba58 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 9 Dec 2022 17:45:30 +0000 Subject: [PATCH 02/19] rename --- src/diffusers/models/attention.py | 18 +++++++++--------- src/diffusers/models/unet_2d_blocks.py | 16 ++++++++-------- src/diffusers/models/unet_2d_condition.py | 10 +++++----- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 7144a0fde59d..4b44986c1fb9 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -176,7 +176,7 @@ def __init__( self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, cross_attention_inputs=None, return_dict: bool = True): + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, cross_attention_kwargs=None, return_dict: bool = True): """ Args: hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. @@ -214,7 +214,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, cros # 2. Blocks for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep, cross_attention_inputs=cross_attention_inputs) + hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs) # 3. 
Output if self.is_input_continuous: @@ -457,14 +457,14 @@ def __init__( # f" correctly and a GPU is available: {e}" # ) - def forward(self, hidden_states, context=None, timestep=None, cross_attention_inputs=None): + def forward(self, hidden_states, context=None, timestep=None, cross_attention_kwargs=None): # 1. Self-Attention norm_hidden_states = ( self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) ) if self.only_cross_attention: - hidden_states = self.attn1(norm_hidden_states, context=context, cross_attention_inputs=cross_attention_inputs) + hidden_states + hidden_states = self.attn1(norm_hidden_states, context=context, cross_attention_kwargs=cross_attention_kwargs) + hidden_states else: hidden_states = self.attn1(norm_hidden_states) + hidden_states @@ -473,7 +473,7 @@ def forward(self, hidden_states, context=None, timestep=None, cross_attention_in norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - hidden_states = self.attn2(norm_hidden_states, context=context, cross_attention_inputs=cross_attention_inputs) + hidden_states + hidden_states = self.attn2(norm_hidden_states, context=context, cross_attention_kwargs=cross_attention_kwargs) + hidden_states # 3. Feed-forward hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states @@ -561,16 +561,16 @@ def set_attention_slice(self, slice_size): self._slice_size = slice_size self.attn_fn = SlicedAttentionProc(self.heads, self.upcast_attention) - def set_attn_proc(self, attn_proc: CrossAttentionProcMixin): + def set_cross_attn_proc(self, attn_proc: CrossAttentionProcMixin): if not isinstance(attn_proc, CrossAttentionProcMixin): subclass = attn_proc.__bases__ if hasattr(attn_proc, "__bases__") else None raise ValueError(f"`attn_proc` should be a subclass of {CrossAttentionProc}, but is of type {type(attn_proc)} and a subclass of {subclass}.") self.attn_proc = attn_proc - def forward(self, hidden_states, context=None, cross_attention_inputs=None): + def forward(self, hidden_states, context=None, cross_attention_kwargs=None): # attn - cross_attention_inputs = cross_attention_inputs if cross_attention_inputs is not None else {} - hidden_states = self.attn(hidden_states, self.to_q, self.to_k, self.to_v, context=context, **cross_attention_inputs) + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + hidden_states = self.attn(hidden_states, self.to_q, self.to_k, self.to_v, context=context, **cross_attention_kwargs) # linear proj hidden_states = self.to_out[0](hidden_states) # dropout diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 465fa6432558..fbac62764016 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -408,10 +408,10 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, cross_attention_inputs=None): + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, cross_attention_kwargs=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_inputs=cross_attention_inputs).sample + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, 
cross_attention_kwargs=cross_attention_kwargs).sample hidden_states = resnet(hidden_states, temb) return hidden_states @@ -588,7 +588,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, cross_attention_inputs=None): + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, cross_attention_kwargs=None): output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -605,11 +605,11 @@ def custom_forward(*inputs): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, cross_attention_inputs + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, cross_attention_kwargs )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_inputs=cross_attention_inputs).sample + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs).sample output_states += (hidden_states,) @@ -1175,7 +1175,7 @@ def forward( res_hidden_states_tuple, temb=None, encoder_hidden_states=None, - cross_attention_inputs=None, + cross_attention_kwargs=None, upsample_size=None, ): for resnet, attn in zip(self.resnets, self.attentions): @@ -1197,11 +1197,11 @@ def custom_forward(*inputs): hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, cross_attention_inputs + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, cross_attention_kwargs )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_inputs=cross_attention_inputs).sample + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs).sample if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index fbb968d5def6..89a42e7ed847 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -232,7 +232,7 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) - def set_cross_attention_class(self, cross_attention_cls): + def set_cross_attention_processor(self, cross_attention_cls): # set recursively @@ -311,7 +311,7 @@ def forward( timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, - cross_attention_inputs: Optional[Dict[str, Any]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -387,7 +387,7 @@ def forward( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, - cross_attention_inputs=cross_attention_inputs, + cross_attention_kwargs=cross_attention_kwargs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) @@ -395,7 +395,7 @@ def forward( down_block_res_samples += 
res_samples # 4. mid - sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states, cross_attention_inputs=cross_attention_inputs) + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs) # 5. up for i, upsample_block in enumerate(self.up_blocks): @@ -415,7 +415,7 @@ def forward( temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, - cross_attention_inputs=cross_attention_inputs, + cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, ) else: From 4e981587bed2aa4e73a3d6b96066ede8c4ebb32f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 9 Dec 2022 17:49:21 +0000 Subject: [PATCH 03/19] up --- src/diffusers/models/unet_2d_condition.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 89a42e7ed847..97c03959054e 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Dict, Any import torch import torch.nn as nn @@ -22,6 +22,7 @@ from ..modeling_utils import ModelMixin from ..utils import BaseOutput, logging from .embeddings import TimestepEmbedding, Timesteps +from .cross_attention_processors import CrossAttentionProcMixin from .unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, @@ -232,9 +233,17 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) - def set_cross_attention_processor(self, cross_attention_cls): + def set_cross_attention_processor(self, cross_attention_cls: CrossAttentionProcMixin): # set recursively + def fn_recursive_set_cross_attn_proc(module: torch.nn.Module): + if hasattr(module, "set_cross_attn_proc"): + module.set_cross_attn_proc(cross_attention_cls) + for child in module.children(): + fn_recursive_set_cross_attn_proc(child) + + for module in self.children(): + fn_recursive_set_cross_attn_proc(module) def set_attention_slice(self, slice_size): r""" From 92a0d017d35d5f092c9ee71e9369e1b86d9ad8a5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 12 Dec 2022 12:54:50 +0100 Subject: [PATCH 04/19] Apply suggestions from code review --- src/diffusers/models/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 4b44986c1fb9..fdca03aa5438 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -559,7 +559,7 @@ def set_attention_slice(self, slice_size): raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") self._slice_size = slice_size - self.attn_fn = SlicedAttentionProc(self.heads, self.upcast_attention) + self.attn_proc = SlicedAttentionProc(self.heads, self.upcast_attention) def set_cross_attn_proc(self, attn_proc: CrossAttentionProcMixin): if not isinstance(attn_proc, CrossAttentionProcMixin): @@ -570,7 +570,7 @@ def set_cross_attn_proc(self, attn_proc: CrossAttentionProcMixin): def forward(self, hidden_states, context=None, cross_attention_kwargs=None): # attn cross_attention_kwargs = cross_attention_kwargs 
if cross_attention_kwargs is not None else {} - hidden_states = self.attn(hidden_states, self.to_q, self.to_k, self.to_v, context=context, **cross_attention_kwargs) + hidden_states = self.attn_proc(hidden_states, self.to_q, self.to_k, self.to_v, context=context, **cross_attention_kwargs) # linear proj hidden_states = self.to_out[0](hidden_states) # dropout From 4d72931bcaa5529b481c411a5a366d8226716000 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 13:54:38 +0100 Subject: [PATCH 05/19] better --- src/diffusers/models/attention.py | 39 ++++++++++++++++--- .../models/cross_attention_processors.py | 17 +++++--- src/diffusers/models/unet_2d_blocks.py | 32 ++++++++++++--- src/diffusers/models/unet_2d_condition.py | 10 +++-- 4 files changed, 79 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 2034e52c451a..1c1f4e0fe945 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -176,7 +176,14 @@ def __init__( self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, cross_attention_kwargs=None, return_dict: bool = True): + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + cross_attention_kwargs=None, + return_dict: bool = True, + ): """ Args: hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. @@ -214,7 +221,12 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, cros # 2. Blocks for block in self.transformer_blocks: - hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs) + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + ) # 3. Output if self.is_input_continuous: @@ -427,12 +439,24 @@ def __init__( # 3. Feed-forward self.norm3 = nn.LayerNorm(dim) - def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, cross_attention_kwargs=None): + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + ): # 1. Self-Attention norm_hidden_states = ( self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) ) - attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **cross_attention_kwargs) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) hidden_states = attn_output + hidden_states if self.attn2 is not None: @@ -440,7 +464,12 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, atte norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **cross_attention_kwargs) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) hidden_states = attn_output + hidden_states # 3. 
Feed-forward diff --git a/src/diffusers/models/cross_attention_processors.py b/src/diffusers/models/cross_attention_processors.py index 50837b7d181e..b6090e1ee2c0 100644 --- a/src/diffusers/models/cross_attention_processors.py +++ b/src/diffusers/models/cross_attention_processors.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Union + import torch -from torch import nn -from typing import Union, Optional import torch.nn.functional as F +from torch import nn from ..utils.import_utils import is_xformers_available @@ -311,7 +312,7 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states= query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] - attn_slice = attn.get_attention_scores(query_slice, key_slice, attention_mask[start_idx: end_idx]) + attn_slice = attn.get_attention_scores(query_slice, key_slice, attention_mask[start_idx:end_idx]) attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) @@ -367,7 +368,7 @@ def forward(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=N query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] - attn_slice = attn.get_attention_scores(query_slice, key_slice, attention_mask[start_idx: end_idx]) + attn_slice = attn.get_attention_scores(query_slice, key_slice, attention_mask[start_idx:end_idx]) attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) @@ -383,4 +384,10 @@ def forward(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=N return hidden_states -AttnProcessor = Union[CrossAttnProcessor, XFormersCrossAttnProcessor, SlicedAttnProcessor, CrossAttnAddedKVProcessor, SlicedAttnAddedKVProcessor] +AttnProcessor = Union[ + CrossAttnProcessor, + XFormersCrossAttnProcessor, + SlicedAttnProcessor, + CrossAttnAddedKVProcessor, + SlicedAttnAddedKVProcessor, +] diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index fdfaca3f73d1..1d5f2ea16fd7 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -15,7 +15,8 @@ import torch from torch import nn -from .attention import AttentionBlock, CrossAttention, DualTransformer2DModel, Transformer2DModel +from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel +from .cross_attention_processors import CrossAttention, CrossAttnAddedKVProcessor from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D @@ -481,10 +482,16 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None): + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs).sample + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample hidden_states = resnet(hidden_states, temb) return hidden_states @@ -543,6 +550,7 @@ def __init__( norm_num_groups=resnet_groups, bias=True, 
upcast_softmax=True, + processor=CrossAttnAddedKVProcessor() ) ) resnets.append( @@ -749,7 +757,9 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None): + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): # TODO(Patrick, William) - attention mask is not used output_states = () @@ -774,7 +784,11 @@ def custom_forward(*inputs): )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs).sample + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample output_states += (hidden_states,) @@ -1310,6 +1324,7 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, + processor=CrossAttnAddedKVProcessor() ) ) self.attentions = nn.ModuleList(attentions) @@ -1562,7 +1577,11 @@ def custom_forward(*inputs): )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs).sample + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -2115,6 +2134,7 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, + processor=CrossAttnAddedKVProcessor() ) ) self.attentions = nn.ModuleList(attentions) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index f1d54a11b6fd..f347c2c0a776 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import List, Optional, Tuple, Union, Dict, Any +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -21,8 +21,8 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput, logging -from .embeddings import TimestepEmbedding, Timesteps from .cross_attention_processors import CrossAttentionProcMixin +from .embeddings import TimestepEmbedding, Timesteps from .unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, @@ -449,7 +449,11 @@ def forward( # 4. mid sample = self.mid_block( - sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, ) # 5. 
up From c2a0b4d60b562526cabbf4abdd6c7d52544910f9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 14:07:56 +0100 Subject: [PATCH 06/19] up --- src/diffusers/models/attention.py | 1 + src/diffusers/models/cross_attention_processors.py | 1 + src/diffusers/models/unet_2d_condition.py | 14 +++++++------- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 1c1f4e0fe945..2d6e4358be70 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -451,6 +451,7 @@ def forward( norm_hidden_states = ( self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) ) + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, diff --git a/src/diffusers/models/cross_attention_processors.py b/src/diffusers/models/cross_attention_processors.py index b6090e1ee2c0..5a5a667cbd1c 100644 --- a/src/diffusers/models/cross_attention_processors.py +++ b/src/diffusers/models/cross_attention_processors.py @@ -191,6 +191,7 @@ def prepare_attention_mask(self, attention_mask, target_length): class CrossAttnProcessor: def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape + import ipdb; ipdb.set_trace() attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index f347c2c0a776..86c3b0e09bbe 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput, logging -from .cross_attention_processors import CrossAttentionProcMixin +from .cross_attention_processors import AttnProcessor from .embeddings import TimestepEmbedding, Timesteps from .unet_2d_blocks import ( CrossAttnDownBlock2D, @@ -266,17 +266,17 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) - def set_cross_attention_processor(self, cross_attention_cls: CrossAttentionProcMixin): + def set_attn_processor(self, processor: AttnProcessor): # set recursively - def fn_recursive_set_cross_attn_proc(module: torch.nn.Module): - if hasattr(module, "set_cross_attn_proc"): - module.set_cross_attn_proc(cross_attention_cls) + def fn_recursive_attn_processor(module: torch.nn.Module): + if hasattr(module, "set_processor"): + module.set_processor(processor) for child in module.children(): - fn_recursive_set_cross_attn_proc(child) + fn_recursive_attn_processor(child) for module in self.children(): - fn_recursive_set_cross_attn_proc(module) + fn_recursive_attn_processor(module) def set_attention_slice(self, slice_size): r""" From a4a2b934179afd1c446af9e91910649fbefe5e5f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 14:25:46 +0100 Subject: [PATCH 07/19] finish --- src/diffusers/models/attention.py | 27 +++++++- .../models/cross_attention_processors.py | 37 ++++++----- src/diffusers/models/unet_2d_blocks.py | 33 +++++----- .../versatile_diffusion/modeling_text_unet.py | 62 +++++++++++++++---- 4 files changed, 115 insertions(+), 44 deletions(-) diff --git 
a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 2d6e4358be70..5c0f053e5db5 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -314,6 +314,31 @@ def reshape_batch_dim_to_heads(self, tensor): tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + if use_memory_efficient_attention_xformers: + if not is_xformers_available(): + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + def forward(self, hidden_states): residual = hidden_states batch, channel, height, width = hidden_states.shape @@ -454,7 +479,7 @@ def forward( cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} attn_output = self.attn1( norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, **cross_attention_kwargs, ) diff --git a/src/diffusers/models/cross_attention_processors.py b/src/diffusers/models/cross_attention_processors.py index 5a5a667cbd1c..f807b7c23bab 100644 --- a/src/diffusers/models/cross_attention_processors.py +++ b/src/diffusers/models/cross_attention_processors.py @@ -191,8 +191,6 @@ def prepare_attention_mask(self, attention_mask, target_length): class CrossAttnProcessor: def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape - import ipdb; ipdb.set_trace() - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) query = attn.to_q(hidden_states) @@ -218,7 +216,9 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states= class CrossAttnAddedKVProcessor: - def forward(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) @@ -226,17 +226,17 @@ def forward(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=N hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) - query = attn.reshape_heads_to_batch_dim(query) + query = attn.head_to_batch_dim(query) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) - key = attn.reshape_heads_to_batch_dim(key) - value = attn.reshape_heads_to_batch_dim(value) + key = 
attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) - encoder_hidden_states_value_proj = attn.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) @@ -250,6 +250,9 @@ def forward(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=N # dropout hidden_states = attn.to_out[1](hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + return hidden_states @@ -333,7 +336,10 @@ class SlicedAttnAddedKVProcessor: def __init__(self, slice_size): self.slice_size = self.slice_size - def forward(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) @@ -342,17 +348,17 @@ def forward(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=N query = attn.to_q(hidden_states) dim = query.shape[-1] - query = attn.reshape_heads_to_batch_dim(query) + query = attn.head_to_batch_dim(query) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - key = attn.reshape_heads_to_batch_dim(key) - value = attn.reshape_heads_to_batch_dim(value) - encoder_hidden_states_key_proj = attn.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) - encoder_hidden_states_value_proj = attn.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) @@ -382,6 +388,9 @@ def forward(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=N # dropout hidden_states = attn.to_out[1](hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + return hidden_states diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 1d5f2ea16fd7..6dbfe1a8033b 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -550,7 +550,7 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, - processor=CrossAttnAddedKVProcessor() + processor=CrossAttnAddedKVProcessor(), ) ) resnets.append( @@ -571,19 +571,19 @@ def __init__( 
self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn - residual = hidden_states - hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states.transpose(1, 2), attention_mask=attention_mask, + **cross_attention_kwargs, ) - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual # resnet hidden_states = resnet(hidden_states, temb) @@ -1324,7 +1324,7 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, - processor=CrossAttnAddedKVProcessor() + processor=CrossAttnAddedKVProcessor(), ) ) self.attentions = nn.ModuleList(attentions) @@ -1353,23 +1353,23 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): output_states = () + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} for resnet, attn in zip(self.resnets, self.attentions): # resnet hidden_states = resnet(hidden_states, temb) # attn - residual = hidden_states - hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states.transpose(1, 2), attention_mask=attention_mask, + **cross_attention_kwargs, ) - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual output_states += (hidden_states,) @@ -2134,7 +2134,7 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, - processor=CrossAttnAddedKVProcessor() + processor=CrossAttnAddedKVProcessor(), ) ) self.attentions = nn.ModuleList(attentions) @@ -2171,7 +2171,9 @@ def forward( encoder_hidden_states=None, upsample_size=None, attention_mask=None, + cross_attention_kwargs=None, ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} for resnet, attn in zip(self.resnets, self.attentions): # resnet # pop res hidden states @@ -2182,15 +2184,12 @@ def forward( hidden_states = resnet(hidden_states, temb) # attn - residual = hidden_states - hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states.transpose(1, 2), attention_mask=attention_mask, + **cross_attention_kwargs, ) - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index c83a347fe93b..5b5a5e9cbc6e 100644 --- 
a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -351,6 +351,18 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1) + def set_attn_processor(self, processor: AttnProcessor): + # set recursively + def fn_recursive_attn_processor(module: torch.nn.Module): + if hasattr(module, "set_processor"): + module.set_processor(processor) + + for child in module.children(): + fn_recursive_attn_processor(child) + + for module in self.children(): + fn_recursive_attn_processor(module) + def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. @@ -427,6 +439,7 @@ def forward( encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -512,6 +525,7 @@ def forward( temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) @@ -520,7 +534,11 @@ def forward( # 4. mid sample = self.mid_block( - sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, ) # 5. up @@ -541,6 +559,7 @@ def forward( temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, ) @@ -840,7 +859,9 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): # TODO(Patrick, William) - attention mask is not used output_states = () @@ -861,10 +882,15 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + cross_attention_kwargs, )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample output_states += (hidden_states,) @@ -1042,6 +1068,7 @@ def forward( res_hidden_states_tuple, temb=None, encoder_hidden_states=None, + cross_attention_kwargs=None, upsample_size=None, attention_mask=None, ): @@ -1068,10 +1095,15 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + cross_attention_kwargs, )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1166,11 +1198,16 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, 
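`set_attn_processor` as added here is a plain recursive walk: any submodule that exposes `set_processor` is handed the same processor instance, and the recursion continues into its children regardless. The same idea works detached from the model class; a sketch (the standalone function name is mine, not part of the patch):

import torch


def set_attn_processor(model: torch.nn.Module, processor) -> None:
    # hand the processor to every submodule that knows how to accept one
    def _recurse(module: torch.nn.Module) -> None:
        if hasattr(module, "set_processor"):
            module.set_processor(processor)
        for child in module.children():
            _recurse(child)

    for module in model.children():
        _recurse(module)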
encoder_hidden_states=None, attention_mask=None): - # TODO(Patrick, William) - attention_mask is currently not used. Implement once used + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample hidden_states = resnet(hidden_states, temb) return hidden_states @@ -1230,6 +1267,7 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, + processor=CrossAttnAddedKVProcessor(), ) ) resnets.append( @@ -1250,19 +1288,19 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn - residual = hidden_states - hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states.transpose(1, 2), attention_mask=attention_mask, + **cross_attention_kwargs, ) - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual # resnet hidden_states = resnet(hidden_states, temb) From c5e7d9e179b6120093da1d5ebe3133c7327eada1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 14:29:08 +0100 Subject: [PATCH 08/19] up --- src/diffusers/models/cross_attention_processors.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/cross_attention_processors.py b/src/diffusers/models/cross_attention_processors.py index f807b7c23bab..d7d211c96ab8 100644 --- a/src/diffusers/models/cross_attention_processors.py +++ b/src/diffusers/models/cross_attention_processors.py @@ -135,8 +135,13 @@ def set_attention_slice(self, slice_size): def set_processor(self, processor: "AttnProcessor"): self.processor = processor - def forward(self, *args, **kwargs): - return self.processor(self, *args, **kwargs) + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): + # The `CrossAttention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + return self.processor( + self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs + ) def batch_to_head_dim(self, tensor): head_size = self.heads From 6b8865001ecc5a0a8d9b52fdbe897f551adf7259 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 14:33:00 +0100 Subject: [PATCH 09/19] rename --- src/diffusers/models/attention.py | 2 +- .../{cross_attention_processors.py => cross_attention.py} | 0 src/diffusers/models/unet_2d_blocks.py | 2 +- src/diffusers/models/unet_2d_condition.py 
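After this change every attention variant is just a callable with the signature `(attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs)`, and `CrossAttention.forward` only dispatches to whichever processor is installed. For orientation, a hypothetical processor that reproduces the default math while printing the shapes it sees might look like the following (sketch only, not part of the patch; `prepare_attention_mask`, `head_to_batch_dim`, `batch_to_head_dim`, and `get_attention_scores` are the helpers used throughout this series):

import torch


class ShapeLoggingAttnProcessor:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.head_to_batch_dim(attn.to_k(encoder_hidden_states))
        value = attn.head_to_batch_dim(attn.to_v(encoder_hidden_states))
        print(f"query tokens: {query.shape[1]}, key/value tokens: {key.shape[1]}")

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        # output projection, then dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states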
| 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename src/diffusers/models/{cross_attention_processors.py => cross_attention.py} (100%) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 5c0f053e5db5..9fe6a8034c22 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -24,7 +24,7 @@ from ..models.embeddings import ImagePositionalEmbeddings from ..utils import BaseOutput from ..utils.import_utils import is_xformers_available -from .cross_attention_processors import CrossAttention +from .cross_attention import CrossAttention @dataclass diff --git a/src/diffusers/models/cross_attention_processors.py b/src/diffusers/models/cross_attention.py similarity index 100% rename from src/diffusers/models/cross_attention_processors.py rename to src/diffusers/models/cross_attention.py diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 6dbfe1a8033b..215ce681d95a 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -16,7 +16,7 @@ from torch import nn from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel -from .cross_attention_processors import CrossAttention, CrossAttnAddedKVProcessor +from .cross_attention import CrossAttention, CrossAttnAddedKVProcessor from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 86c3b0e09bbe..d5ccb169e0b4 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput, logging -from .cross_attention_processors import AttnProcessor +from .cross_attention import AttnProcessor from .embeddings import TimestepEmbedding, Timesteps from .unet_2d_blocks import ( CrossAttnDownBlock2D, From 9c32a36ef110a0c1f01cd9265984c3eb8a5f9ef2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 14:39:31 +0100 Subject: [PATCH 10/19] correct versatile --- .../pipelines/versatile_diffusion/modeling_text_unet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 5b5a5e9cbc6e..405025ddc2c5 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -7,8 +7,8 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...modeling_utils import ModelMixin from ...models.attention import CrossAttention, DualTransformer2DModel, Transformer2DModel +from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor from ...models.embeddings import TimestepEmbedding, Timesteps -from ...models.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn as UNetMidBlockFlatSimpleCrossAttn from ...models.unet_2d_condition import UNet2DConditionOutput from ...utils import logging @@ -1213,8 +1213,8 @@ def forward( return hidden_states -# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with 
UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat -class UnCLIPUNetMidBlockFlatCrossAttn(nn.Module): +# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlatSimpleCrossAttn(nn.Module): def __init__( self, in_channels: int, From a13d2a859325ab9a352e5ee471c11800efe8269a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 14:58:02 +0100 Subject: [PATCH 11/19] up --- tests/models/test_models_unet_2d.py | 65 +++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index 9071495b58d2..bcddbb3efbea 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -391,6 +391,71 @@ def check_slicable_dim_attr(module: torch.nn.Module): for module in model.children(): check_slicable_dim_attr(module) + def test_special_attn_proc(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + model(**inputs_dict) + + import ipdb; ipdb.set_trace() + + class AttnEasyProc(torch.nn.Module): + def __init__(self, num): + super().__init__() + self.weight = torch.nn.Parameter(torch.tensor(num)) + self.is_run = False + self.number = 0 + self.counter = 0 + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states += self.weight + + self.is_run = True + self.counter += 1 + self.number = number + + return hidden_states + + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + processor = AttnEasyProc(5.0) + + # model.set_attn_processor(processor) +# model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample + model(**inputs_dict) + import ipdb; ipdb.set_trace() + class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet2DModel From 5db574e74696d487ef68e93c04b677650e3b6fe8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 15:00:08 +0100 Subject: [PATCH 12/19] up --- src/diffusers/models/cross_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index d7d211c96ab8..94a002fa0d79 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ 
-140,7 +140,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None # here we simply pass along all tensors to the selected processor class # For standard processors that are defined here, `**cross_attention_kwargs` is empty return self.processor( - self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs + self, hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **cross_attention_kwargs ) def batch_to_head_dim(self, tensor): From 9b449a8988b2fc3dcb3c46c47943028341f336d4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 16:05:32 +0100 Subject: [PATCH 13/19] up --- src/diffusers/models/cross_attention.py | 6 +++++- tests/models/test_models_unet_2d.py | 20 ++++++-------------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 94a002fa0d79..9a47430a9cb8 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -140,7 +140,11 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None # here we simply pass along all tensors to the selected processor class # For standard processors that are defined here, `**cross_attention_kwargs` is empty return self.processor( - self, hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **cross_attention_kwargs + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, ) def batch_to_head_dim(self, tensor): diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index bcddbb3efbea..91192f17fb00 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -392,16 +392,6 @@ def check_slicable_dim_attr(module: torch.nn.Module): check_slicable_dim_attr(module) def test_special_attn_proc(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - model = self.model_class(**init_dict) - model.to(torch_device) - model(**inputs_dict) - - import ipdb; ipdb.set_trace() - class AttnEasyProc(torch.nn.Module): def __init__(self, num): super().__init__() @@ -451,10 +441,12 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma processor = AttnEasyProc(5.0) - # model.set_attn_processor(processor) -# model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample - model(**inputs_dict) - import ipdb; ipdb.set_trace() + model.set_attn_processor(processor) + model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample + + assert processor.counter == 12 + assert processor.is_run + assert processor.number == 123 class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): From 0e987e1edfa1504d5e94592bec6b223baf1365bb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 16:14:47 +0100 Subject: [PATCH 14/19] up --- src/diffusers/models/cross_attention.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 9a47430a9cb8..f9ec99ed4ee3 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -294,7 +294,7 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states= class SlicedAttnProcessor: def __init__(self, 
slice_size): - self.slice_size = self.slice_size + self.slice_size = slice_size def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape @@ -324,8 +324,9 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states= query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - attn_slice = attn.get_attention_scores(query_slice, key_slice, attention_mask[start_idx:end_idx]) + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) @@ -383,8 +384,9 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states= query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - attn_slice = attn.get_attention_scores(query_slice, key_slice, attention_mask[start_idx:end_idx]) + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) From fbfc842769dca157a0f00da3725e234c1713aa58 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 15:56:42 +0000 Subject: [PATCH 15/19] fix --- src/diffusers/models/cross_attention.py | 28 ++++++++++++++++--------- src/diffusers/models/unet_2d_blocks.py | 6 +++--- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index f9ec99ed4ee3..c33e3c2b1919 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -120,16 +120,24 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten raise e processor = XFormersCrossAttnProcessor() - self.set_processor(processor) + else: + processor = CrossAttnProcessor() + + self.set_processor(processor) def set_attention_slice(self, slice_size): if slice_size is not None and slice_size > self.sliceable_head_dim: raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") - if self.added_kv_proj_dim is not None: + if slice_size is not None and self.added_kv_proj_dim is not None: processor = SlicedAttnAddedKVProcessor(slice_size) - else: + elif slice_size is not None: processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = CrossAttnAddedKVProcessor() + else: + processor = CrossAttnProcessor() + self.set_processor(processor) def set_processor(self, processor: "AttnProcessor"): @@ -203,12 +211,11 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states= attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) @@ -229,6 +236,7 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states= residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, 
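Both sliced processors now walk the (batch * heads) dimension in chunks of `slice_size`, and the mask slice is only taken when a mask exists. Stripped of the projections and reshapes, the loop they implement looks roughly like this (sketch only; the real code routes the score computation through `get_attention_scores`):

import torch


def sliced_attention(query, key, value, slice_size, attention_mask=None):
    # query, key, value: (batch * heads, seq_len, head_dim), already projected and reshaped
    out = torch.zeros_like(query)
    scale = query.shape[-1] ** -0.5
    for start in range(0, query.shape[0], slice_size):
        end = min(start + slice_size, query.shape[0])
        scores = torch.bmm(query[start:end], key[start:end].transpose(-1, -2)) * scale
        if attention_mask is not None:
            scores = scores + attention_mask[start:end]  # additive mask, sliced only when present
        out[start:end] = torch.bmm(scores.softmax(dim=-1), value[start:end])
    return out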
sequence_length, _ = hidden_states.shape + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) @@ -281,7 +289,8 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states= key = attn.head_to_batch_dim(key).contiguous() value = attn.head_to_batch_dim(value).contiguous() - hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = hidden_states.to(query.dtype) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj @@ -302,14 +311,12 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states= attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) query = attn.to_q(hidden_states) - dim = query.shape[-1] + query = attn.head_to_batch_dim(query) encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) @@ -344,11 +351,12 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states= class SlicedAttnAddedKVProcessor: def __init__(self, slice_size): - self.slice_size = self.slice_size + self.slice_size = slice_size def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 215ce681d95a..abd7a4fe882f 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -580,7 +580,7 @@ def forward( # attn hidden_states = attn( hidden_states, - encoder_hidden_states=encoder_hidden_states.transpose(1, 2), + encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **cross_attention_kwargs, ) @@ -1366,7 +1366,7 @@ def forward( # attn hidden_states = attn( hidden_states, - encoder_hidden_states=encoder_hidden_states.transpose(1, 2), + encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **cross_attention_kwargs, ) @@ -2186,7 +2186,7 @@ def forward( # attn hidden_states = attn( hidden_states, - encoder_hidden_states=encoder_hidden_states.transpose(1, 2), + encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **cross_attention_kwargs, ) From 244020c85cf14426604995dd459cdd9a11d5e0ad Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 16:59:39 +0100 Subject: [PATCH 16/19] Apply suggestions from code review --- src/diffusers/models/cross_attention.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index c33e3c2b1919..321601b5864f 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -69,9 +69,7 @@ def __init__( # is split across the batch axis to save memory # You can set slice_size with `set_attention_slice` self.sliceable_head_dim = heads - self._slice_size = None - 
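The xformers hunk now forwards the prepared mask as `attn_bias` and casts the result back to the query dtype. The mask produced by `prepare_attention_mask` is an additive bias (zeros for visible positions, large negative values for masked ones), which matches the additive `attn_bias` convention of `memory_efficient_attention`. A small illustration of that additive-mask convention, written without xformers and with made-up shapes:

import torch

query = torch.randn(1, 4, 8)
key = torch.randn(1, 6, 8)
value = torch.randn(1, 6, 8)

bias = torch.zeros(1, 4, 6)
bias[..., -2:] = -10000.0  # effectively removes the last two key positions

scores = torch.bmm(query, key.transpose(-1, -2)) * query.shape[-1] ** -0.5 + bias
probs = scores.softmax(dim=-1)
out = torch.bmm(probs, value)

assert out.shape == (1, 4, 8)
assert torch.allclose(probs[..., -2:], torch.zeros(1, 4, 2), atol=1e-4)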
self._use_memory_efficient_attention_xformers = False self.added_kv_proj_dim = added_kv_proj_dim if norm_num_groups is not None: From 8865b18e250ee513d1e82fa0869b1bfbb63989ea Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 16:00:44 +0000 Subject: [PATCH 17/19] make style --- .../pipelines/versatile_diffusion/modeling_text_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 405025ddc2c5..3d3f210c4183 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1297,7 +1297,7 @@ def forward( # attn hidden_states = attn( hidden_states, - encoder_hidden_states=encoder_hidden_states.transpose(1, 2), + encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **cross_attention_kwargs, ) From 9cb7ee6185b7839871f1576515ff4e90e940baf4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 18:32:48 +0100 Subject: [PATCH 18/19] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- src/diffusers/models/cross_attention.py | 8 ++++---- .../pipelines/versatile_diffusion/modeling_text_unet.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 321601b5864f..0e299bf1ddb8 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -204,7 +204,7 @@ def prepare_attention_mask(self, attention_mask, target_length): class CrossAttnProcessor: - def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) @@ -230,7 +230,7 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states= class CrossAttnAddedKVProcessor: - def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape @@ -272,7 +272,7 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states= class XFormersCrossAttnProcessor: - def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) @@ -303,7 +303,7 @@ class SlicedAttnProcessor: def __init__(self, slice_size): self.slice_size = slice_size - def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) diff --git 
a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 3d3f210c4183..d637cc432398 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -862,7 +862,7 @@ def __init__( def forward( self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None ): - # TODO(Patrick, William) - attention mask is not used + # Reminder(Patrick, William) - attention mask is not used at the moment output_states = () for resnet, attn in zip(self.resnets, self.attentions): From 9d5e5ca9b18b55d864d931a4b6199c99065c5cd7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 18:36:06 +0100 Subject: [PATCH 19/19] add error message --- src/diffusers/models/cross_attention.py | 10 +++++++++- .../versatile_diffusion/modeling_text_unet.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 0e299bf1ddb8..98173cb8a406 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -95,7 +95,15 @@ def __init__( def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): if use_memory_efficient_attention_xformers: - if not is_xformers_available(): + if self.added_kv_proj_dim is not None: + # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + raise NotImplementedError( + "Memory efficient attention with `xformers` is currently not supported when" + " `self.added_kv_proj_dim` is defined." + ) + elif not is_xformers_available(): raise ModuleNotFoundError( "Refer to https://github.com/facebookresearch/xformers for more information on how to install" " xformers", diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index d637cc432398..3d3f210c4183 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -862,7 +862,7 @@ def __init__( def forward( self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None ): - # Reminder(Patrick, William) - attention mask is not used at the moment + # TODO(Patrick, William) - attention mask is not used output_states = () for resnet, attn in zip(self.resnets, self.attentions):
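With the new guard, requesting xformers attention on a model whose attention layers were built with `added_kv_proj_dim` (the blocks configured with `processor=CrossAttnAddedKVProcessor()`) raises `NotImplementedError` instead of silently running an unsupported mask layout. Callers that enable xformers opportunistically can treat it as optional; a rough sketch, assuming `unet` is any module tree containing these `CrossAttention` layers (the helper function is mine, not part of the patch):

import torch


def try_enable_xformers(unet: torch.nn.Module) -> bool:
    try:
        for module in unet.modules():
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(True)
        return True
    except (NotImplementedError, ModuleNotFoundError, ValueError):
        # added-KV attention has no xformers path yet, or xformers / CUDA is unavailable
        return False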