
Commit 4ae54b3

[attention] Fix attention (#2656)
* [attention] Fix attention
* fix
* correct
1 parent fa7a576 commit 4ae54b3

2 files changed: +6 -3 lines changed


Diff for: src/diffusers/models/attention.py (+5 -2)

```diff
@@ -271,9 +271,10 @@ def __init__(
     def forward(
         self,
         hidden_states,
+        attention_mask=None,
         encoder_hidden_states=None,
+        encoder_attention_mask=None,
         timestep=None,
-        attention_mask=None,
         cross_attention_kwargs=None,
         class_labels=None,
     ):
@@ -302,12 +303,14 @@ def forward(
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
             )
+            # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
+            # prepare attention mask here

             # 2. Cross-Attention
             attn_output = self.attn2(
                 norm_hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=encoder_attention_mask,
                 **cross_attention_kwargs,
             )
             hidden_states = attn_output + hidden_states
```
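The substance of the fix is that the self-attention mask and the cross-attention mask cover different key sequences: self-attention keys are the image tokens themselves, while cross-attention keys come from `encoder_hidden_states` and generally have a different length, so a single `attention_mask` argument cannot serve both. A minimal standalone sketch in plain PyTorch (not the diffusers implementation; all sizes are illustrative) of why the old routing breaks:

```python
import torch

batch, img_tokens, txt_tokens, dim = 2, 64, 77, 320

hidden_states = torch.randn(batch, img_tokens, dim)          # query stream (image tokens)
encoder_hidden_states = torch.randn(batch, txt_tokens, dim)  # key/value stream (text tokens)

# Additive masks: 0.0 keeps a position, a large negative value drops it.
attention_mask = torch.zeros(batch, img_tokens, img_tokens)          # self-attn: keys are image tokens
encoder_attention_mask = torch.zeros(batch, img_tokens, txt_tokens)  # cross-attn: keys are text tokens

def attend(q, k, v, mask):
    # Scaled dot-product attention with an additive mask on the scores.
    scores = q @ k.transpose(-1, -2) / k.shape[-1] ** 0.5
    return (scores + mask).softmax(dim=-1) @ v

# Cross-attention: queries from the image stream, keys/values from the text encoder.
out = attend(hidden_states, encoder_hidden_states, encoder_hidden_states, encoder_attention_mask)

# Before this commit, `attention_mask` (shaped [batch, 64, 64]) was forwarded here
# instead, which cannot broadcast against the [batch, 64, 77] score matrix.
```

Splitting the parameters lets a caller mask padded text tokens in cross-attention, via `encoder_attention_mask`, without disturbing the self-attention mask.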

Diff for: tests/pipelines/stable_diffusion/test_stable_diffusion.py (+1 -1)

```diff
@@ -737,7 +737,7 @@ def test_stable_diffusion_vae_tiling(self):

         # make sure that more than 4 GB is allocated
         mem_bytes = torch.cuda.max_memory_allocated()
-        assert mem_bytes > 4e9
+        assert mem_bytes > 5e9
         assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2

     def test_stable_diffusion_fp16_vs_autocast(self):
```
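The test change only raises the peak-memory threshold that the VAE tiling test asserts against. For reference, the measurement pattern it relies on looks like this (a standalone sketch assuming a CUDA device is available; the tensor sizes are illustrative, not taken from the test):

```python
import torch

# Clear the high-water mark so the peak reflects only what follows.
torch.cuda.reset_peak_memory_stats()

# Allocate ~2.1e9 bytes of float32, then a temporary that doubles the peak.
x = torch.randn(1024, 1024, 512, device="cuda")
y = x * 2

mem_bytes = torch.cuda.max_memory_allocated()
print(f"peak allocation: {mem_bytes / 1e9:.2f} GB")
assert mem_bytes > 4e9  # analogous check; the test used 4e9 before this commit, 5e9 after
```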
