diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 6f81373376eb..9fe6a8034c22 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -24,6 +24,7 @@ from ..models.embeddings import ImagePositionalEmbeddings from ..utils import BaseOutput from ..utils.import_utils import is_xformers_available +from .cross_attention import CrossAttention @dataclass @@ -175,7 +176,14 @@ def __init__( self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + cross_attention_kwargs=None, + return_dict: bool = True, + ): """ Args: hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. @@ -213,7 +221,12 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, retu # 2. Blocks for block in self.transformer_blocks: - hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep) + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + ) # 3. Output if self.is_input_continuous: @@ -287,6 +300,20 @@ def __init__( self._use_memory_efficient_attention_xformers = False + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.num_heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.num_heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): if use_memory_efficient_attention_xformers: if not is_xformers_available(): @@ -312,20 +339,6 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten raise e self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - def forward(self, hidden_states): residual = hidden_states batch, channel, height, width = hidden_states.shape @@ -423,7 +436,8 @@ def __init__( bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, - ) # is a self-attention + ) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) # 2. Cross-Attn @@ -450,58 +464,39 @@ def __init__( # 3. 
Feed-forward self.norm3 = nn.LayerNorm(dim) - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - if use_memory_efficient_attention_xformers: - if not is_xformers_available(): - print("Here is how to install it") - raise ModuleNotFoundError( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers", - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" - " only available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - if self.attn2 is not None: - self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - - def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None): + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + ): # 1. Self-Attention norm_hidden_states = ( self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) ) - - if self.only_cross_attention: - hidden_states = ( - self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states - ) - else: - hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states if self.attn2 is not None: # 2. Cross-Attention norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - hidden_states = ( - self.attn2( - norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask - ) - + hidden_states + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, ) + hidden_states = attn_output + hidden_states # 3. Feed-forward hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states @@ -509,229 +504,6 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, atte return hidden_states -class CrossAttention(nn.Module): - r""" - A cross attention layer. - - Parameters: - query_dim (`int`): The number of channels in the query. - cross_attention_dim (`int`, *optional*): - The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. - heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. - dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - bias (`bool`, *optional*, defaults to False): - Set to `True` for the query, key, and value linear layers to contain a bias parameter. 
- """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias=False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - added_kv_proj_dim: Optional[int] = None, - norm_num_groups: Optional[int] = None, - ): - super().__init__() - inner_dim = dim_head * heads - cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - - self.scale = dim_head**-0.5 - - self.heads = heads - # for slice_size > 0 the attention score computation - # is split across the batch axis to save memory - # You can set slice_size with `set_attention_slice` - self.sliceable_head_dim = heads - self._slice_size = None - self._use_memory_efficient_attention_xformers = False - self.added_kv_proj_dim = added_kv_proj_dim - - if norm_num_groups is not None: - self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) - else: - self.group_norm = None - - self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - - if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) - self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) - - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(inner_dim, query_dim)) - self.to_out.append(nn.Dropout(dropout)) - - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def set_attention_slice(self, slice_size): - if slice_size is not None and slice_size > self.sliceable_head_dim: - raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") - - self._slice_size = slice_size - - def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = hidden_states.shape - - encoder_hidden_states = encoder_hidden_states - - if self.group_norm is not None: - hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = self.to_q(hidden_states) - dim = query.shape[-1] - query = self.reshape_heads_to_batch_dim(query) - - if self.added_kv_proj_dim is not None: - key = self.to_k(hidden_states) - value = self.to_v(hidden_states) - encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) - - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) - encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) - - key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) - value = 
torch.concat([encoder_hidden_states_value_proj, value], dim=1) - else: - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - key = self.to_k(encoder_hidden_states) - value = self.to_v(encoder_hidden_states) - - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - if attention_mask is not None: - if attention_mask.shape[-1] != query.shape[1]: - target_length = query.shape[1] - attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) - - # attention, what we cannot get enough of - if self._use_memory_efficient_attention_xformers: - hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) - # Some versions of xformers return output in fp32, cast it back to the dtype of the input - hidden_states = hidden_states.to(query.dtype) - else: - if self._slice_size is None or query.shape[0] // self._slice_size == 1: - hidden_states = self._attention(query, key, value, attention_mask) - else: - hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) - - # linear proj - hidden_states = self.to_out[0](hidden_states) - - # dropout - hidden_states = self.to_out[1](hidden_states) - return hidden_states - - def _attention(self, query, key, value, attention_mask=None): - if self.upcast_attention: - query = query.float() - key = key.float() - - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - - if attention_mask is not None: - attention_scores = attention_scores + attention_mask - - if self.upcast_softmax: - attention_scores = attention_scores.float() - - attention_probs = attention_scores.softmax(dim=-1) - - # cast back to the original dtype - attention_probs = attention_probs.to(value.dtype) - - # compute attention output - hidden_states = torch.bmm(attention_probs, value) - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - - def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): - batch_size_attention = query.shape[0] - hidden_states = torch.zeros( - (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype - ) - slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] - for i in range(hidden_states.shape[0] // slice_size): - start_idx = i * slice_size - end_idx = (i + 1) * slice_size - - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - - if self.upcast_attention: - query_slice = query_slice.float() - key_slice = key_slice.float() - - attn_slice = torch.baddbmm( - torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), - query_slice, - key_slice.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - - if attention_mask is not None: - attn_slice = attn_slice + attention_mask[start_idx:end_idx] - - if self.upcast_softmax: - attn_slice = attn_slice.float() - - attn_slice = attn_slice.softmax(dim=-1) - - # cast back to the original dtype - attn_slice = attn_slice.to(value.dtype) - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - # reshape hidden_states - hidden_states = 
self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - - def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): - # TODO attention_mask - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - - class FeedForward(nn.Module): r""" A feed-forward layer. diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py new file mode 100644 index 000000000000..98173cb8a406 --- /dev/null +++ b/src/diffusers/models/cross_attention.py @@ -0,0 +1,428 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ..utils.import_utils import is_xformers_available + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class CrossAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. 
+ """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + processor: Optional["AttnProcessor"] = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + processor = processor if processor is not None else CrossAttnProcessor() + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + if use_memory_efficient_attention_xformers: + if self.added_kv_proj_dim is not None: + # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + raise NotImplementedError( + "Memory efficient attention with `xformers` is currently not supported when" + " `self.added_kv_proj_dim` is defined." + ) + elif not is_xformers_available(): + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + processor = XFormersCrossAttnProcessor() + else: + processor = CrossAttnProcessor() + + self.set_processor(processor) + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = CrossAttnAddedKVProcessor() + else: + processor = CrossAttnProcessor() + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor"): + self.processor = processor + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): + # The `CrossAttention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor): + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def get_attention_scores(self, query, key, attention_mask=None): + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask(self, attention_mask, target_length): + head_size = self.heads + if attention_mask is None: + return attention_mask + + if attention_mask.shape[-1] != target_length: + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + return attention_mask + + +class CrossAttnProcessor: + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, 
sequence_length) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class CrossAttnAddedKVProcessor: + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class XFormersCrossAttnProcessor: + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class SlicedAttnProcessor: + def __init__(self, slice_size): + self.slice_size = slice_size + + def 
__call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + for i in range(hidden_states.shape[0] // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class SlicedAttnAddedKVProcessor: + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + for i in range(hidden_states.shape[0] // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + 
hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +AttnProcessor = Union[ + CrossAttnProcessor, + XFormersCrossAttnProcessor, + SlicedAttnProcessor, + CrossAttnAddedKVProcessor, + SlicedAttnAddedKVProcessor, +] diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index a3a9d39bc8ae..abd7a4fe882f 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -15,7 +15,8 @@ import torch from torch import nn -from .attention import AttentionBlock, CrossAttention, DualTransformer2DModel, Transformer2DModel +from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel +from .cross_attention import CrossAttention, CrossAttnAddedKVProcessor from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D @@ -481,11 +482,16 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): - # TODO(Patrick, William) - attention_mask is currently not used. Implement once used + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample hidden_states = resnet(hidden_states, temb) return hidden_states @@ -544,6 +550,7 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, + processor=CrossAttnAddedKVProcessor(), ) ) resnets.append( @@ -564,19 +571,19 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn - residual = hidden_states - hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) hidden_states = attn( hidden_states, - encoder_hidden_states=encoder_hidden_states.transpose(1, 2), + encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + **cross_attention_kwargs, ) - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual # resnet hidden_states = resnet(hidden_states, temb) @@ -750,7 +757,9 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): # TODO(Patrick, William) - attention mask is not used output_states = 
() @@ -771,10 +780,15 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + cross_attention_kwargs, )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample output_states += (hidden_states,) @@ -1310,6 +1324,7 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, + processor=CrossAttnAddedKVProcessor(), ) ) self.attentions = nn.ModuleList(attentions) @@ -1338,23 +1353,23 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): output_states = () + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} for resnet, attn in zip(self.resnets, self.attentions): # resnet hidden_states = resnet(hidden_states, temb) # attn - residual = hidden_states - hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) hidden_states = attn( hidden_states, - encoder_hidden_states=encoder_hidden_states.transpose(1, 2), + encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + **cross_attention_kwargs, ) - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual output_states += (hidden_states,) @@ -1531,6 +1546,7 @@ def forward( res_hidden_states_tuple, temb=None, encoder_hidden_states=None, + cross_attention_kwargs=None, upsample_size=None, attention_mask=None, ): @@ -1557,10 +1573,15 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + cross_attention_kwargs, )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -2113,6 +2134,7 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, + processor=CrossAttnAddedKVProcessor(), ) ) self.attentions = nn.ModuleList(attentions) @@ -2149,7 +2171,9 @@ def forward( encoder_hidden_states=None, upsample_size=None, attention_mask=None, + cross_attention_kwargs=None, ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} for resnet, attn in zip(self.resnets, self.attentions): # resnet # pop res hidden states @@ -2160,15 +2184,12 @@ def forward( hidden_states = resnet(hidden_states, temb) # attn - residual = hidden_states - hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) hidden_states = attn( hidden_states, - encoder_hidden_states=encoder_hidden_states.transpose(1, 2), + encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + **cross_attention_kwargs, ) - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git 
a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 4612e650e4e4..d5ccb169e0b4 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -21,6 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput, logging +from .cross_attention import AttnProcessor from .embeddings import TimestepEmbedding, Timesteps from .unet_2d_blocks import ( CrossAttnDownBlock2D, @@ -265,6 +266,18 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + def set_attn_processor(self, processor: AttnProcessor): + # set recursively + def fn_recursive_attn_processor(module: torch.nn.Module): + if hasattr(module, "set_processor"): + module.set_processor(processor) + + for child in module.children(): + fn_recursive_attn_processor(child) + + for module in self.children(): + fn_recursive_attn_processor(module) + def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. @@ -341,6 +354,7 @@ def forward( encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -426,6 +440,7 @@ def forward( temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) @@ -434,7 +449,11 @@ def forward( # 4. mid sample = self.mid_block( - sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, ) # 5. 
up @@ -455,6 +474,7 @@ def forward( temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, ) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index c83a347fe93b..3d3f210c4183 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -7,8 +7,8 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...modeling_utils import ModelMixin from ...models.attention import CrossAttention, DualTransformer2DModel, Transformer2DModel +from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor from ...models.embeddings import TimestepEmbedding, Timesteps -from ...models.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn as UNetMidBlockFlatSimpleCrossAttn from ...models.unet_2d_condition import UNet2DConditionOutput from ...utils import logging @@ -351,6 +351,18 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1) + def set_attn_processor(self, processor: AttnProcessor): + # set recursively + def fn_recursive_attn_processor(module: torch.nn.Module): + if hasattr(module, "set_processor"): + module.set_processor(processor) + + for child in module.children(): + fn_recursive_attn_processor(child) + + for module in self.children(): + fn_recursive_attn_processor(module) + def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. @@ -427,6 +439,7 @@ def forward( encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -512,6 +525,7 @@ def forward( temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) @@ -520,7 +534,11 @@ def forward( # 4. mid sample = self.mid_block( - sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, ) # 5. 
up @@ -541,6 +559,7 @@ def forward( temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, ) @@ -840,7 +859,9 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): # TODO(Patrick, William) - attention mask is not used output_states = () @@ -861,10 +882,15 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + cross_attention_kwargs, )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample output_states += (hidden_states,) @@ -1042,6 +1068,7 @@ def forward( res_hidden_states_tuple, temb=None, encoder_hidden_states=None, + cross_attention_kwargs=None, upsample_size=None, attention_mask=None, ): @@ -1068,10 +1095,15 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + cross_attention_kwargs, )[0] else: hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1166,18 +1198,23 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): - # TODO(Patrick, William) - attention_mask is currently not used. 
Implement once used + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample hidden_states = resnet(hidden_states, temb) return hidden_states -# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat -class UnCLIPUNetMidBlockFlatCrossAttn(nn.Module): +# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlatSimpleCrossAttn(nn.Module): def __init__( self, in_channels: int, @@ -1230,6 +1267,7 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, + processor=CrossAttnAddedKVProcessor(), ) ) resnets.append( @@ -1250,19 +1288,19 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn - residual = hidden_states - hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) hidden_states = attn( hidden_states, - encoder_hidden_states=encoder_hidden_states.transpose(1, 2), + encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + **cross_attention_kwargs, ) - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual # resnet hidden_states = resnet(hidden_states, temb) diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index 9071495b58d2..91192f17fb00 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -391,6 +391,63 @@ def check_slicable_dim_attr(module: torch.nn.Module): for module in model.children(): check_slicable_dim_attr(module) + def test_special_attn_proc(self): + class AttnEasyProc(torch.nn.Module): + def __init__(self, num): + super().__init__() + self.weight = torch.nn.Parameter(torch.tensor(num)) + self.is_run = False + self.number = 0 + self.counter = 0 + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = 
torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states += self.weight + + self.is_run = True + self.counter += 1 + self.number = number + + return hidden_states + + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + processor = AttnEasyProc(5.0) + + model.set_attn_processor(processor) + model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample + + assert processor.counter == 12 + assert processor.is_run + assert processor.number == 123 + class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet2DModel
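
The refactor moves `CrossAttention` into its own `cross_attention.py` module and routes its `forward` through a pluggable processor object, with `CrossAttnProcessor` installed by default. A minimal sketch of the default path at the level of a single layer; the dimensions below are arbitrary:

```python
import torch

from diffusers.models.cross_attention import CrossAttention

# query_dim is the channel dim of hidden_states, cross_attention_dim that of encoder_hidden_states
attn = CrossAttention(query_dim=64, cross_attention_dim=32, heads=4, dim_head=16)

hidden_states = torch.randn(2, 10, 64)         # (batch, query tokens, query_dim)
encoder_hidden_states = torch.randn(2, 7, 32)  # (batch, context tokens, cross_attention_dim)

# forward() simply delegates to attn.processor, which is CrossAttnProcessor() unless overridden
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(type(attn.processor).__name__, out.shape)  # CrossAttnProcessor torch.Size([2, 10, 64])
```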
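
The old `reshape_heads_to_batch_dim` / `reshape_batch_dim_to_heads` helpers become `head_to_batch_dim` / `batch_to_head_dim` on the new class. They fold the head dimension into the batch dimension so every processor can operate on `(batch * heads, seq_len, dim_head)` tensors, and they are exact inverses of each other. A stand-alone re-implementation to illustrate the shapes (not library code):

```python
import torch

heads = 4

def head_to_batch_dim(t):
    # (B, S, H*D) -> (B*H, S, D)
    b, s, d = t.shape
    return t.reshape(b, s, heads, d // heads).permute(0, 2, 1, 3).reshape(b * heads, s, d // heads)

def batch_to_head_dim(t):
    # (B*H, S, D) -> (B, S, H*D)
    b, s, d = t.shape
    return t.reshape(b // heads, heads, s, d).permute(0, 2, 1, 3).reshape(b // heads, s, d * heads)

x = torch.randn(2, 10, 64)
assert head_to_batch_dim(x).shape == (8, 10, 16)
assert torch.equal(batch_to_head_dim(head_to_batch_dim(x)), x)  # exact round trip
```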
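
Because `CrossAttention.forward` forwards everything to `self.processor`, custom attention behaviour can be injected without subclassing the module, and extra keyword arguments travel through `cross_attention_kwargs` down to the processor's `__call__`, as the new `test_special_attn_proc` test exercises. A sketch for a single layer; the `scale` argument is an illustrative, made-up kwarg, not part of the library API:

```python
import torch

from diffusers.models.cross_attention import CrossAttention


class ScaledOutputProcessor:
    """Same interface as CrossAttnProcessor, plus one custom keyword argument."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        attention_mask = attn.prepare_attention_mask(attention_mask, hidden_states.shape[1])

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.head_to_batch_dim(attn.to_k(encoder_hidden_states))
        value = attn.head_to_batch_dim(attn.to_v(encoder_hidden_states))

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        hidden_states = attn.to_out[0](hidden_states)  # linear proj
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states * scale                   # custom behaviour (scale is hypothetical)


attn = CrossAttention(query_dim=64, heads=4, dim_head=16)
attn.set_processor(ScaledOutputProcessor())

x = torch.randn(1, 10, 64)
out = attn(x, scale=0.5)  # extra kwargs reach the processor via **cross_attention_kwargs
```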
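
`set_attention_slice` no longer stores a `_slice_size` flag; it swaps in `SlicedAttnProcessor` (or `SlicedAttnAddedKVProcessor` when `added_kv_proj_dim` is set), which computes the attention scores for `slice_size` rows of the `(batch * heads)` dimension at a time to cap peak memory. A sketch with arbitrary sizes:

```python
import torch

from diffusers.models.cross_attention import CrossAttention

attn = CrossAttention(query_dim=64, heads=8, dim_head=8)
attn.set_attention_slice(2)           # dispatches to SlicedAttnProcessor(slice_size=2)
print(type(attn.processor).__name__)  # SlicedAttnProcessor

x = torch.randn(1, 16, 64)            # batch * heads = 8, processed in 4 chunks of 2
out = attn(x)                         # self-attention, since no encoder_hidden_states is given
print(out.shape)                      # torch.Size([1, 16, 64])

attn.set_attention_slice(None)        # back to the full, unsliced CrossAttnProcessor
```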
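
The UnCLIP-style blocks (`UNetMidBlock2DSimpleCrossAttn`, `SimpleCrossAttnDownBlock2D`, and friends) used to flatten the spatial map, transpose `encoder_hidden_states`, and re-add the residual around the attention call themselves; that bookkeeping now lives inside `CrossAttnAddedKVProcessor`, which those blocks install at construction time via the new `processor=` argument. A sketch mirroring how they instantiate their attention layer, with arbitrary dimensions:

```python
import torch

from diffusers.models.cross_attention import CrossAttention, CrossAttnAddedKVProcessor

attn = CrossAttention(
    query_dim=64,
    cross_attention_dim=64,
    heads=4,
    dim_head=16,
    added_kv_proj_dim=32,
    norm_num_groups=32,
    bias=True,
    upcast_softmax=True,
    processor=CrossAttnAddedKVProcessor(),
)

hidden_states = torch.randn(1, 64, 8, 8)        # spatial feature map (B, C, H, W)
encoder_hidden_states = torch.randn(1, 32, 77)  # (B, added_kv_proj_dim, tokens); the processor transposes it

out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([1, 64, 8, 8]) -- the residual is added back inside the processor
```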
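
At the model level, `UNet2DConditionModel.set_attn_processor` walks the module tree and installs the given processor on every submodule that exposes `set_processor`, and the new `cross_attention_kwargs` argument is threaded from `UNet2DConditionModel.forward` through the down/mid/up blocks and `Transformer2DModel` into each processor call. A usage sketch; the small config below is an assumption, chosen along the lines of the repository's unit tests rather than taken from this diff:

```python
import torch

from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import CrossAttnProcessor

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)

# one call replaces the processor on every CrossAttention layer in the UNet
unet.set_attn_processor(CrossAttnProcessor())

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 32)

# cross_attention_kwargs is forwarded untouched to every processor; it should only
# contain keyword arguments that the installed processor's __call__ accepts
out = unet(sample, timestep=1, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs={}).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])
```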