
Commit 8f12786

Restore passing of emb to KAttentionBlock#forward, on the basis that its removal caused test failures. Also restore the passing of emb to checkpointed calls to KAttentionBlock#forward.

Parent: b4f5cb9

1 file changed (+6, -2 lines)

Diff for: src/diffusers/models/unet_2d_blocks.py

@@ -1750,7 +1750,7 @@ def custom_forward(*inputs):
                     create_custom_forward(attn, return_dict=False),
                     hidden_states,
                     encoder_hidden_states,
-                    None,  # emb
+                    temb,
                     attention_mask,
                     cross_attention_kwargs,
                     encoder_attention_mask,
@@ -1761,6 +1761,7 @@ def custom_forward(*inputs):
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
                     encoder_attention_mask=encoder_attention_mask,
@@ -2876,7 +2877,7 @@ def custom_forward(*inputs):
                     create_custom_forward(attn, return_dict=False),
                     hidden_states,
                     encoder_hidden_states,
-                    None,  # temb
+                    temb,
                     attention_mask,
                     cross_attention_kwargs,
                     encoder_attention_mask,
@@ -2887,6 +2888,7 @@ def custom_forward(*inputs):
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
                     encoder_attention_mask=encoder_attention_mask,
@@ -2970,6 +2972,8 @@ def forward(
         self,
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        # TODO: mark emb as non-optional (self.norm2 requires it).
+        #       requires assessing impact of change to positional param interface.
         emb: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
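
Why the checkpointed call sites need temb in that exact slot: torch.utils.checkpoint.checkpoint forwards only positional arguments to the wrapped custom_forward(*inputs), so a None placeholder lands in KAttentionBlock#forward's emb parameter and breaks layers such as self.norm2 that require the embedding. Below is a minimal sketch of that failure mode; ToyAttentionBlock is a hypothetical stand-in for the real KAttentionBlock, not the diffusers implementation.

import torch
import torch.utils.checkpoint


class ToyAttentionBlock(torch.nn.Module):
    # Stand-in for KAttentionBlock: emb is declared Optional in the
    # signature, but the scale/shift projection (playing the role of
    # self.norm2) actually requires it.
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)
        self.scale_shift = torch.nn.Linear(dim, 2 * dim)

    def forward(self, hidden_states, encoder_hidden_states=None, emb=None):
        # Fails with a TypeError if emb is None, mirroring the TODO above.
        scale, shift = self.scale_shift(emb).chunk(2, dim=-1)
        return self.proj(hidden_states) * (1 + scale) + shift


def create_custom_forward(module):
    # Same wrapper pattern as unet_2d_blocks.py: checkpoint() passes only
    # positional *args through, so every parameter must appear in order.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


attn = ToyAttentionBlock()
hidden_states = torch.randn(2, 8, requires_grad=True)
temb = torch.randn(2, 8)

# With temb restored in the emb slot (as in this commit), the checkpointed
# call succeeds and gradients flow through the block:
out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(attn), hidden_states, None, temb, use_reentrant=False
)
out.sum().backward()

# Passing None in that slot instead (the change this commit reverts) would
# raise a TypeError inside forward, since emb cannot be None there.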
