@@ -1750,7 +1750,7 @@ def custom_forward(*inputs):
                     create_custom_forward(attn, return_dict=False),
                     hidden_states,
                     encoder_hidden_states,
-                    None,  # emb
+                    temb,
                     attention_mask,
                     cross_attention_kwargs,
                     encoder_attention_mask,
@@ -1761,6 +1761,7 @@ def custom_forward(*inputs):
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
                     encoder_attention_mask=encoder_attention_mask,
@@ -2876,7 +2877,7 @@ def custom_forward(*inputs):
                     create_custom_forward(attn, return_dict=False),
                     hidden_states,
                     encoder_hidden_states,
-                    None,  # temb
+                    temb,
                     attention_mask,
                     cross_attention_kwargs,
                     encoder_attention_mask,
@@ -2887,6 +2888,7 @@ def custom_forward(*inputs):
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
                     encoder_attention_mask=encoder_attention_mask,
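The hunks above make the gradient-checkpointing path consistent with the eager path: torch.utils.checkpoint.checkpoint forwards only positional arguments through create_custom_forward, so the old None placeholder in the emb slot meant the time embedding never reached the attention block when checkpointing was enabled; this change passes temb on both paths. A minimal sketch of that mechanism follows, using a toy module and wrapper rather than the diffusers source (the return_dict handling is omitted):

import torch
import torch.nn as nn


class ToyAttnBlock(nn.Module):
    # Stand-in for the attention block: `emb` modulates the normalized features.
    def __init__(self, dim, emb_dim):
        super().__init__()
        self.norm = nn.GroupNorm(1, dim)
        self.scale = nn.Linear(emb_dim, dim)

    def forward(self, hidden_states, encoder_hidden_states=None, emb=None):
        hidden_states = self.norm(hidden_states)
        if emb is not None:
            # When emb is None this modulation is silently skipped.
            hidden_states = hidden_states * (1 + self.scale(emb)[:, :, None, None])
        return (hidden_states,)


def create_custom_forward(module):
    # checkpoint() passes its inputs positionally, so the module's keyword
    # parameters have to be filled positionally, in declaration order.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


attn = ToyAttnBlock(dim=8, emb_dim=4)
hidden_states = torch.randn(2, 8, 16, 16)
temb = torch.randn(2, 4)

# Old checkpointed call: None sits in the emb slot, so temb never reaches the block.
out_old = torch.utils.checkpoint.checkpoint(
    create_custom_forward(attn), hidden_states, None, None, use_reentrant=False
)[0]
# New checkpointed call: temb fills the emb slot, matching the eager call
# attn(hidden_states, encoder_hidden_states=None, emb=temb).
out_new = torch.utils.checkpoint.checkpoint(
    create_custom_forward(attn), hidden_states, None, temb, use_reentrant=False
)[0]
print(torch.allclose(out_old, out_new))  # False: the embedding changes the output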
@@ -2970,6 +2972,8 @@ def forward(
         self,
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        # TODO: mark emb as non-optional (self.norm2 requires it).
+        # requires assessing impact of change to positional param interface.
         emb: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
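The TODO records why emb is effectively required: per the comment, self.norm2 needs the embedding, so the Optional[...] = None default only defers the failure from the call site to the norm. A hedged sketch of that failure mode, using a hypothetical AdaGroupNorm-style stand-in rather than the actual diffusers layer:

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaGroupNormLike(nn.Module):
    # Illustrative stand-in for an embedding-conditioned group norm.
    def __init__(self, embedding_dim, out_dim, num_groups):
        super().__init__()
        self.linear = nn.Linear(embedding_dim, out_dim * 2)
        self.num_groups = num_groups

    def forward(self, x, emb):
        # emb is consumed unconditionally; passing None raises a TypeError here.
        scale, shift = self.linear(emb).chunk(2, dim=1)
        x = F.group_norm(x, self.num_groups)
        return x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]


norm2 = AdaGroupNormLike(embedding_dim=4, out_dim=8, num_groups=1)
hidden_states = torch.randn(2, 8, 16, 16)

norm2(hidden_states, torch.randn(2, 4))  # works
try:
    norm2(hidden_states, None)  # what an emb=None default eventually triggers
except TypeError as err:
    print(f"emb=None fails at the norm: {err}")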