
Rename attention #2691

Merged: 7 commits, Mar 15, 2023
docs/source/en/optimization/torch2.0.mdx (1 addition, 1 deletion)
@@ -50,7 +50,7 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
```Python
import torch
from diffusers import StableDiffusionPipeline
-from diffusers.models.cross_attention import AttnProcessor2_0
+from diffusers.models.attention_processor import AttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.unet.set_attn_processor(AttnProcessor2_0())
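# Illustrative continuation (not shown in the hunk above): with AttnProcessor2_0 set,
# the pipeline is used as usual; the prompt here is only an example.
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]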
examples/community/stable_diffusion_controlnet_img2img.py (1 addition, 1 deletion)
@@ -713,7 +713,7 @@ def __call__(
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
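For context, the renamed `AttentionProcessor` classes receive these kwargs at call time. A minimal sketch of the user-facing path, assuming LoRA attention processors have already been loaded onto the UNet (the checkpoint path and `scale` value below are illustrative, and the plain Stable Diffusion pipeline stands in for the ControlNet community pipelines):

```Python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Load previously saved LoRA attention processors (illustrative local path).
pipe.unet.load_attn_procs("path/to/lora")

# Entries in cross_attention_kwargs are forwarded to the attention processors;
# LoRA processors read "scale" to blend the LoRA weights into the base attention.
image = pipe(
    "a photo of an astronaut riding a horse on mars",
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```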
examples/community/stable_diffusion_controlnet_inpaint.py (1 addition, 1 deletion)
@@ -868,7 +868,7 @@ def __call__(
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
@@ -911,7 +911,7 @@ def __call__(
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
examples/dreambooth/train_dreambooth_lora.py (2 additions, 4 deletions)
@@ -47,7 +47,7 @@
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -723,9 +723,7 @@ def main(args):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]

-lora_attn_procs[name] = LoRACrossAttnProcessor(
-    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
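For reference, the surrounding loop in these training scripts builds one `LoRAAttnProcessor` per attention layer of the UNet and then swaps them in; a condensed sketch, assuming `unet` is the script's `UNet2DConditionModel`:

```Python
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention layers ("attn1") take no encoder states, so no cross-attention dim.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)  # the only trainable parameters
```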
@@ -22,7 +22,7 @@

from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
@@ -561,9 +561,7 @@ def main():
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]

-lora_attn_procs[name] = LoRACrossAttnProcessor(
-    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
examples/research_projects/lora/train_text_to_image_lora.py (2 additions, 4 deletions)
@@ -43,7 +43,7 @@
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -536,9 +536,7 @@ def main():
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]

-lora_attn_procs[name] = LoRACrossAttnProcessor(
-    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
examples/text_to_image/train_text_to_image_lora.py (2 additions, 4 deletions)
@@ -41,7 +41,7 @@
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -474,9 +474,7 @@ def main():
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]

-lora_attn_procs[name] = LoRACrossAttnProcessor(
-    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)

src/diffusers/loaders.py (2 additions, 2 deletions)
@@ -17,7 +17,7 @@

import torch

-from .models.cross_attention import LoRACrossAttnProcessor
+from .models.attention_processor import LoRAAttnProcessor
from .models.modeling_utils import _get_model_file
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging

@@ -207,7 +207,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

-attn_processors[key] = LoRACrossAttnProcessor(
+attn_processors[key] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
)
attn_processors[key].load_state_dict(value_dict)
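From the user side, the loader rebuilds these processors from a saved state dict. A small sketch of the round trip, assuming `unet` already carries LoRA attention processors as in the training scripts above (the output directory is illustrative):

```Python
# Persist the LoRA attention processors after training...
unet.save_attn_procs("sd-model-lora")

# ...and later reconstruct LoRAAttnProcessor instances from the saved weights.
unet.load_attn_procs("sd-model-lora")
```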
src/diffusers/models/attention.py (3 additions, 3 deletions)
@@ -19,7 +19,7 @@
from torch import nn

from ..utils.import_utils import is_xformers_available
-from .cross_attention import CrossAttention
+from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings


@@ -220,7 +220,7 @@ def __init__(
)

# 1. Self-Attn
-self.attn1 = CrossAttention(
+self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
@@ -234,7 +234,7 @@ def __init__(

# 2. Cross-Attn
if cross_attention_dim is not None:
-self.attn2 = CrossAttention(
+self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
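The rename is purely mechanical: `Attention` keeps the old `CrossAttention` behaviour, acting as self-attention when called on `hidden_states` alone and as cross-attention when `encoder_hidden_states` is passed. A standalone sketch with illustrative dimensions:

```Python
import torch
from diffusers.models.attention_processor import Attention

attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)

hidden_states = torch.randn(1, 64, 320)          # image tokens
encoder_hidden_states = torch.randn(1, 77, 768)  # text-encoder tokens

out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([1, 64, 320])
```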
src/diffusers/models/attention_flax.py (3 additions, 3 deletions)
@@ -16,7 +16,7 @@
import jax.numpy as jnp


-class FlaxCrossAttention(nn.Module):
+class FlaxAttention(nn.Module):
r"""
A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762

@@ -118,9 +118,9 @@ class FlaxBasicTransformerBlock(nn.Module):

def setup(self):
# self attention (or cross_attention if only_cross_attention is True)
-self.attn1 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
-self.attn2 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
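The Flax rename follows the same pattern. A hedged sketch of initializing the renamed module on its own, assuming the `(query_dim, heads, dim_head)` layout shown above and self-attention when no context is supplied:

```Python
import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxAttention

attn = FlaxAttention(320, 8, 40, dtype=jnp.float32)

hidden_states = jnp.ones((1, 64, 320))
params = attn.init(jax.random.PRNGKey(0), hidden_states)  # self-attention: no context passed
out = attn.apply(params, hidden_states)
print(out.shape)  # (1, 64, 320)
```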