
Commit a2f6ced

patrickvonplaten authored and Jimmy committed
Rename attention (huggingface#2691)
* rename file
* rename attention
* fix more
* rename more
* up
* more deprecation imports
* fixes
1 parent cf6522c commit a2f6ced

28 files changed (+835, -761 lines)
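The commit message's "more deprecation imports" implies that the old `diffusers.models.cross_attention` module is kept as a thin backward-compatibility shim that re-exports the renamed classes and warns when the old names are used. A minimal sketch of what such a shim looks like, assuming the set of re-exported names and the removal version, neither of which is shown in this diff:

```python
# Illustrative shim for diffusers/models/cross_attention.py (not the exact file contents).
# It re-exports the renamed classes and emits a deprecation warning for the old names.
from ..utils import deprecate
from .attention_processor import (  # noqa: F401, keeps old import paths working
    Attention,
    AttnProcessor,
    AttnProcessor2_0,
    LoRAAttnProcessor,
)


class CrossAttention(Attention):
    """Deprecated alias so `from diffusers.models.cross_attention import CrossAttention` still works."""

    def __init__(self, *args, **kwargs):
        message = "CrossAttention was renamed to Attention; import it from diffusers.models.attention_processor."
        deprecate("cross_attention.CrossAttention", "1.0.0", message, standard_warn=False)  # removal version assumed
        super().__init__(*args, **kwargs)


class LoRACrossAttnProcessor(LoRAAttnProcessor):
    """Deprecated alias for LoRAAttnProcessor."""

    def __init__(self, *args, **kwargs):
        message = "LoRACrossAttnProcessor was renamed to LoRAAttnProcessor."
        deprecate("cross_attention.LoRACrossAttnProcessor", "1.0.0", message, standard_warn=False)
        super().__init__(*args, **kwargs)
```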

docs/source/en/optimization/torch2.0.mdx (+1 -1)

@@ -50,7 +50,7 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
 ```Python
 import torch
 from diffusers import StableDiffusionPipeline
-from diffusers.models.cross_attention import AttnProcessor2_0
+from diffusers.models.attention_processor import AttnProcessor2_0
 
 pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
 pipe.unet.set_attn_processor(AttnProcessor2_0())
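The only change to this snippet is the import path. A migration sketch for user code that has to run on both pre- and post-rename releases (the fallback branch is an assumption about older versions, not part of this commit):

```python
import torch
from diffusers import StableDiffusionPipeline

try:
    # New location introduced by this commit.
    from diffusers.models.attention_processor import AttnProcessor2_0
except ImportError:
    # Older diffusers releases (pre-rename) kept the class here.
    from diffusers.models.cross_attention import AttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.unet.set_attn_processor(AttnProcessor2_0())  # use PyTorch 2.0 scaled dot-product attention
```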

examples/community/stable_diffusion_controlnet_img2img.py (+1 -1)

@@ -713,7 +713,7 @@ def __call__(
 The frequency at which the `callback` function will be called. If not specified, the callback will be
 called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
 `self.processor` in
 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
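This docstring change (repeated in the two inpaint variants below) only renames the processor class; `cross_attention_kwargs` behaves exactly as before: whatever dictionary you pass is forwarded to each attention processor's `__call__`. A minimal usage sketch, assuming LoRA attention processors have already been loaded so that the "scale" key is actually consumed (the weights path is a placeholder):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.unet.load_attn_procs("path/to/lora_weights")  # placeholder path to trained LoRA attention processors

image = pipe(
    "a photo of an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.7},  # forwarded to LoRAAttnProcessor.__call__ at every attention layer
).images[0]
```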

examples/community/stable_diffusion_controlnet_inpaint.py (+1 -1)

@@ -868,7 +868,7 @@ def __call__(
 The frequency at which the `callback` function will be called. If not specified, the callback will be
 called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
 `self.processor` in
 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):

examples/community/stable_diffusion_controlnet_inpaint_img2img.py (+1 -1)

@@ -911,7 +911,7 @@ def __call__(
 The frequency at which the `callback` function will be called. If not specified, the callback will be
 called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
 `self.processor` in
 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):

examples/dreambooth/train_dreambooth_lora.py (+2 -4)

@@ -47,7 +47,7 @@
 UNet2DConditionModel,
 )
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -723,9 +723,7 @@ def main(args):
 block_id = int(name[len("down_blocks.")])
 hidden_size = unet.config.block_out_channels[block_id]
 
-lora_attn_procs[name] = LoRACrossAttnProcessor(
-hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
 
 unet.set_attn_processor(lora_attn_procs)
 lora_layers = AttnProcsLayers(unet.attn_processors)
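For context, the diffed assignment lives inside the loop over `unet.attn_processors` that builds one LoRA processor per attention layer. A condensed sketch of that loop with the renamed class, reconstructed from the surrounding training script (the name parsing and the `unet` object come from that script, not from this hunk); the three LoRA training scripts below receive the identical two-line change:

```python
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor

# `unet` is the UNet2DConditionModel being fine-tuned.
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no encoder states); attn2 cross-attends over the text embeddings.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)  # wraps only the LoRA weights as trainable parameters
```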

examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py (+2 -4)

@@ -22,7 +22,7 @@
 
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version
 from diffusers.utils.import_utils import is_xformers_available
@@ -561,9 +561,7 @@ def main():
 block_id = int(name[len("down_blocks.")])
 hidden_size = unet.config.block_out_channels[block_id]
 
-lora_attn_procs[name] = LoRACrossAttnProcessor(
-hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
 
 unet.set_attn_processor(lora_attn_procs)
 lora_layers = AttnProcsLayers(unet.attn_processors)

examples/research_projects/lora/train_text_to_image_lora.py (+2 -4)

@@ -43,7 +43,7 @@
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -536,9 +536,7 @@ def main():
 block_id = int(name[len("down_blocks.")])
 hidden_size = unet.config.block_out_channels[block_id]
 
-lora_attn_procs[name] = LoRACrossAttnProcessor(
-hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
 
 unet.set_attn_processor(lora_attn_procs)
 lora_layers = AttnProcsLayers(unet.attn_processors)

examples/text_to_image/train_text_to_image_lora.py (+2 -4)

@@ -41,7 +41,7 @@
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -474,9 +474,7 @@ def main():
 block_id = int(name[len("down_blocks.")])
 hidden_size = unet.config.block_out_channels[block_id]
 
-lora_attn_procs[name] = LoRACrossAttnProcessor(
-hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
 
 unet.set_attn_processor(lora_attn_procs)

src/diffusers/loaders.py (+2 -2)

@@ -17,7 +17,7 @@
 
 import torch
 
-from .models.cross_attention import LoRACrossAttnProcessor
+from .models.attention_processor import LoRAAttnProcessor
 from .models.modeling_utils import _get_model_file
 from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
 
@@ -207,7 +207,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
 cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
 hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
 
-attn_processors[key] = LoRACrossAttnProcessor(
+attn_processors[key] = LoRAAttnProcessor(
 hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
 )
 attn_processors[key].load_state_dict(value_dict)
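`load_attn_procs` can read both dimensions straight off the checkpoint because the `to_k` LoRA is stored as a `down` matrix of shape `(rank, cross_attention_dim)` and an `up` matrix of shape `(hidden_size, rank)`. A quick sketch confirming that shape convention, with sizes picked to match Stable Diffusion v1.5's first down block (the concrete numbers are illustrative, not from this diff):

```python
from diffusers.models.attention_processor import LoRAAttnProcessor

proc = LoRAAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4)
state = proc.state_dict()

print(state["to_k_lora.down.weight"].shape)  # torch.Size([4, 768])  -> cross_attention_dim = shape[1]
print(state["to_k_lora.up.weight"].shape)    # torch.Size([320, 4])  -> hidden_size = shape[0]
```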

src/diffusers/models/attention.py (+3 -3)

@@ -19,7 +19,7 @@
 from torch import nn
 
 from ..utils.import_utils import is_xformers_available
-from .cross_attention import CrossAttention
+from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings
 
 
@@ -220,7 +220,7 @@ def __init__(
 )
 
 # 1. Self-Attn
-self.attn1 = CrossAttention(
+self.attn1 = Attention(
 query_dim=dim,
 heads=num_attention_heads,
 dim_head=attention_head_dim,
@@ -234,7 +234,7 @@ def __init__(
 
 # 2. Cross-Attn
 if cross_attention_dim is not None:
-self.attn2 = CrossAttention(
+self.attn2 = Attention(
 query_dim=dim,
 cross_attention_dim=cross_attention_dim,
 heads=num_attention_heads,
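After the rename, the single `Attention` module serves both roles in `BasicTransformerBlock`: leaving `cross_attention_dim` unset yields self-attention (attn1), passing it yields cross-attention over encoder states (attn2). A standalone sketch with small illustrative sizes (the sizes and tensors are assumptions, not taken from the commit):

```python
import torch
from diffusers.models.attention_processor import Attention

dim, text_dim = 320, 768  # illustrative: first UNet block width and CLIP text width

self_attn = Attention(query_dim=dim, heads=8, dim_head=40)                                  # attn1-style
cross_attn = Attention(query_dim=dim, cross_attention_dim=text_dim, heads=8, dim_head=40)  # attn2-style

hidden_states = torch.randn(1, 64, dim)               # 64 latent tokens
encoder_hidden_states = torch.randn(1, 77, text_dim)  # 77 text tokens

out1 = self_attn(hidden_states)
out2 = cross_attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out1.shape, out2.shape)  # both torch.Size([1, 64, 320])
```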

src/diffusers/models/attention_flax.py (+3 -3)

@@ -16,7 +16,7 @@
 import jax.numpy as jnp
 
 
-class FlaxCrossAttention(nn.Module):
+class FlaxAttention(nn.Module):
 r"""
 A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
 
@@ -118,9 +118,9 @@ class FlaxBasicTransformerBlock(nn.Module):
 
 def setup(self):
 # self attention (or cross_attention if only_cross_attention is True)
-self.attn1 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
 # cross attention
-self.attn2 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
 self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
 self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
 self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
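The Flax side gets the same treatment: `FlaxCrossAttention` becomes `FlaxAttention`, and as with the PyTorch module it performs self-attention when no context is passed and cross-attention when one is. A small sketch, assuming the positional argument order shown in `setup()` above and illustrative sizes:

```python
import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxAttention

attn = FlaxAttention(320, 8, 40, 0.0, dtype=jnp.float32)  # (dim, n_heads, d_head, dropout), as in setup()

hidden_states = jnp.ones((1, 64, 320))
params = attn.init(jax.random.PRNGKey(0), hidden_states)
out = attn.apply(params, hidden_states)  # self-attention; pass a context array for cross-attention
print(out.shape)  # (1, 64, 320)
```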
