
Commit e531ddf

Rename attention (huggingface#2691)
* rename file
* rename attention
* fix more
* rename more
* up
* more deprecation imports
* fixes
1 parent 9cee6c0 commit e531ddf

19 files changed, +816 -726 lines changed
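
The commit message lists "more deprecation imports", i.e. the old module path is kept around so that existing imports keep working after the rename. Below is a minimal, hypothetical sketch of such a shim for models/cross_attention.py, assuming it simply re-exposes the renamed classes with a warning; the actual file in this commit may be structured differently.

# models/cross_attention.py -- hypothetical backwards-compatibility shim;
# the real deprecation module in this commit may look different.
import warnings

from .attention_processor import Attention, LoRAAttnProcessor


class CrossAttention(Attention):
    # old class name kept importable; warns and defers to the renamed class
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "CrossAttention was renamed to Attention and moved to"
            " diffusers.models.attention_processor; please update your import.",
            FutureWarning,
        )
        super().__init__(*args, **kwargs)


class LoRACrossAttnProcessor(LoRAAttnProcessor):
    # old processor name kept importable; warns and defers to the renamed class
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "LoRACrossAttnProcessor was renamed to LoRAAttnProcessor;"
            " please update your import.",
            FutureWarning,
        )
        super().__init__(*args, **kwargs)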

loaders.py (+2 -2)
@@ -17,7 +17,7 @@
 
 import torch
 
-from .models.cross_attention import LoRACrossAttnProcessor
+from .models.attention_processor import LoRAAttnProcessor
 from .models.modeling_utils import _get_model_file
 from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
 
@@ -207,7 +207,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
             cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
             hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
 
-            attn_processors[key] = LoRACrossAttnProcessor(
+            attn_processors[key] = LoRAAttnProcessor(
                 hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
             )
             attn_processors[key].load_state_dict(value_dict)
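
The load_attn_procs hunk above reads the processor's constructor arguments off the shapes of the serialized LoRA matrices: a LoRA pair's down projection is (rank, cross_attention_dim) and its up projection is (hidden_size, rank). A small illustration with made-up sizes follows; rank itself is presumably derived from the same weights just above the visible context.

import torch

# hypothetical LoRA weights for one attention block, sized only to show
# how the tensor shapes map to LoRAAttnProcessor's arguments
value_dict = {
    "to_k_lora.down.weight": torch.randn(4, 768),  # (rank, cross_attention_dim)
    "to_k_lora.up.weight": torch.randn(320, 4),    # (hidden_size, rank)
}

rank = value_dict["to_k_lora.down.weight"].shape[0]                 # 4
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]  # 768
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]            # 320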

models/attention.py (+3 -3)
@@ -19,7 +19,7 @@
 from torch import nn
 
 from ..utils.import_utils import is_xformers_available
-from .cross_attention import CrossAttention
+from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings
 
 
@@ -220,7 +220,7 @@ def __init__(
         )
 
         # 1. Self-Attn
-        self.attn1 = CrossAttention(
+        self.attn1 = Attention(
             query_dim=dim,
             heads=num_attention_heads,
             dim_head=attention_head_dim,
@@ -234,7 +234,7 @@ def __init__(
 
         # 2. Cross-Attn
         if cross_attention_dim is not None:
-            self.attn2 = CrossAttention(
+            self.attn2 = Attention(
                 query_dim=dim,
                 cross_attention_dim=cross_attention_dim,
                 heads=num_attention_heads,
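
After the rename, a single Attention class covers both uses in this block: attn1 is built without cross_attention_dim and attends over its own hidden states, while attn2 receives cross_attention_dim and attends over encoder hidden states. A minimal sketch with illustrative sizes, assuming a diffusers version that already includes this rename:

import torch

from diffusers.models.attention_processor import Attention

dim, heads, dim_head, cross_dim = 320, 8, 40, 768

# self-attention: keys/values are projected from the same hidden states
attn1 = Attention(query_dim=dim, heads=heads, dim_head=dim_head)
# cross-attention: keys/values are projected from the encoder hidden states
attn2 = Attention(query_dim=dim, cross_attention_dim=cross_dim, heads=heads, dim_head=dim_head)

hidden_states = torch.randn(1, 64, dim)                # (batch, tokens, dim)
encoder_hidden_states = torch.randn(1, 77, cross_dim)  # e.g. text encoder output

out1 = attn1(hidden_states)                         # torch.Size([1, 64, 320])
out2 = attn2(hidden_states, encoder_hidden_states)  # torch.Size([1, 64, 320])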

models/attention_flax.py (+3 -3)
@@ -16,7 +16,7 @@
 import jax.numpy as jnp
 
 
-class FlaxCrossAttention(nn.Module):
+class FlaxAttention(nn.Module):
     r"""
     A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
 
@@ -118,9 +118,9 @@ class FlaxBasicTransformerBlock(nn.Module):
 
     def setup(self):
         # self attention (or cross_attention if only_cross_attention is True)
-        self.attn1 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
         # cross attention
-        self.attn2 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
         self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
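
The Flax change is the same mechanical rename; FlaxBasicTransformerBlock.setup constructs FlaxAttention with the same positional arguments as before. A short sketch of exercising the module on its own via the usual Flax init/apply pattern, with illustrative sizes and assuming its call signature is unchanged by the rename:

import jax
import jax.numpy as jnp

from diffusers.models.attention_flax import FlaxAttention

# illustrative sizes; the positional arguments mirror the setup() call above
attn = FlaxAttention(320, 8, 40, dropout=0.0, dtype=jnp.float32)

hidden_states = jnp.ones((1, 64, 320))
params = attn.init(jax.random.PRNGKey(0), hidden_states)  # initialize parameters
out = attn.apply(params, hidden_states)                   # shape (1, 64, 320)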
