File tree: 2 files changed, +6 −3 lines changed
tests/pipelines/stable_diffusion

@@ -271,9 +271,10 @@ def __init__(
     def forward(
         self,
         hidden_states,
+        attention_mask=None,
         encoder_hidden_states=None,
+        encoder_attention_mask=None,
         timestep=None,
-        attention_mask=None,
         cross_attention_kwargs=None,
         class_labels=None,
     ):
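The net effect of this hunk: the old single `attention_mask` argument is split in two. `attention_mask` now masks the block's self-attention, while the new `encoder_attention_mask` masks the cross-attention over `encoder_hidden_states` (typically padded text-encoder tokens). A minimal call-site sketch, assuming the edited class is diffusers' `BasicTransformerBlock`; the class name, constructor arguments, and shapes are illustrative assumptions, and only the forward() signature comes from the diff:

import torch
from diffusers.models.attention import BasicTransformerBlock  # assumed module path

# Constructor values are illustrative Stable Diffusion-like sizes.
block = BasicTransformerBlock(
    dim=320, num_attention_heads=8, attention_head_dim=40, cross_attention_dim=768
)

hidden_states = torch.randn(2, 64, 320)          # batch 2, 64 latent tokens
encoder_hidden_states = torch.randn(2, 77, 768)  # 77 text-encoder tokens

out = block(
    hidden_states,
    attention_mask=None,          # would mask self-attention over latent tokens
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=None,  # would mask cross-attention over text tokens
)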
@@ -302,12 +303,14 @@ def forward(
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
             )
+            # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
+            # prepare attention mask here

             # 2. Cross-Attention
             attn_output = self.attn2(
                 norm_hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=encoder_attention_mask,
                 **cross_attention_kwargs,
             )
             hidden_states = attn_output + hidden_states
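The two added TODO comments flag that `encoder_attention_mask` is currently forwarded to `attn2` unmodified. A common preparation step, and roughly what diffusers ends up doing upstream in the UNet, is to convert a binary keep/drop mask into an additive bias before it reaches the attention scores. The helper below is a hypothetical sketch of that conversion; its name, fill value, and output shape are assumptions, not part of this diff:

import torch

def prepare_encoder_attention_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: (batch, key_tokens) binary mask -> additive bias.

    1 keeps a token, 0 masks it; masked positions receive a large negative
    bias so their attention weights vanish after the softmax.
    """
    bias = (1 - mask.to(dtype)) * -10000.0
    return bias.unsqueeze(1)  # (batch, 1, key_tokens), broadcasts over queries

# Example: keep the first 70 of 77 text tokens.
mask = torch.ones(2, 77)
mask[:, 70:] = 0
bias = prepare_encoder_attention_mask(mask, torch.float32)  # shape (2, 1, 77)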
tests/pipelines/stable_diffusion

@@ -737,7 +737,7 @@ def test_stable_diffusion_vae_tiling(self):

         # make sure that more than 4 GB is allocated
         mem_bytes = torch.cuda.max_memory_allocated()
-        assert mem_bytes > 4e9
+        assert mem_bytes > 5e9
         assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2

     def test_stable_diffusion_fp16_vs_autocast(self):
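One note on the test change: `torch.cuda.max_memory_allocated()` reports a high-water mark since the statistics were last reset, so the raised 5e9 threshold only reflects the measured run if the peak stats are reset beforehand (the unchanged comment above the assertion still says "more than 4 GB"). A minimal sketch of that measurement pattern, using only standard torch.cuda calls; `run_pipeline` is a placeholder:

import torch

def peak_cuda_bytes(run_pipeline) -> int:
    """Peak CUDA memory (bytes) allocated while run_pipeline() executes."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()  # clear the previous high-water mark
    run_pipeline()
    torch.cuda.synchronize()              # wait for all kernels to finish
    return torch.cuda.max_memory_allocated()

# e.g. assert peak_cuda_bytes(lambda: pipe(prompt)) > 5e9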