From e3bed52830e838f0a4a981ffcc32c5c01f9a3aa2 Mon Sep 17 00:00:00 2001
From: Matrix53 <1079207272@qq.com>
Date: Fri, 1 Nov 2024 09:47:48 +0800
Subject: [PATCH 1/6] Support pass kwargs to sd3 custom attention processor

---
 src/diffusers/models/attention.py                    | 12 +++++++++---
 src/diffusers/models/transformers/transformer_sd3.py |  6 +++++-
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 02ed1f965abf..89a23478731a 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -188,7 +188,11 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         self._chunk_dim = dim

     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Dict[str, Any] = None,
     ):
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
@@ -206,7 +210,9 @@ def forward(

         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )

         # Process attention outputs for the `hidden_states`.
@@ -214,7 +220,7 @@ def forward(
         hidden_states = hidden_states + attn_output

         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2

diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index b28350b8ed9c..ccac106e2e4e 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -334,12 +334,16 @@ def custom_forward(*inputs):
                     hidden_states,
                     encoder_hidden_states,
                     temb,
+                    joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
             else:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )

             # controlnet residual

From 9521b7491cc97ef7982988e3e5421f2d8f75f09c Mon Sep 17 00:00:00 2001
From: Qin Zhou <1079207272@qq.com>
Date: Mon, 9 Dec 2024 18:42:47 +0800
Subject: [PATCH 2/6] fix: set joint_attention_kwargs default as empty dict

Co-authored-by: hlky
---
 src/diffusers/models/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 89a23478731a..e8a47fa8226a 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -192,7 +192,7 @@ def forward(
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
-        joint_attention_kwargs: Dict[str, Any] = None,
+        joint_attention_kwargs: Dict[str, Any] = {},
     ):
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(

From 7208efffe06c162927df0f7c74528548c10f9cf4 Mon Sep 17 00:00:00 2001
From: hlky
Date: Mon, 9 Dec 2024 11:02:31 +0000
Subject: [PATCH 3/6] set default joint_attention_kwargs in transformer_sd3

---
 src/diffusers/models/transformers/transformer_sd3.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index 79c4069e9a37..19c690643881 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -337,7 +337,7 @@ def forward(
         pooled_projections: torch.FloatTensor = None,
         timestep: torch.LongTensor = None,
         block_controlnet_hidden_states: List = None,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = {},
         return_dict: bool = True,
         skip_layers: Optional[List[int]] = None,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:

From 61925512072a746da612eef7fafe213b22d39c4a Mon Sep 17 00:00:00 2001
From: hlky
Date: Mon, 9 Dec 2024 11:17:17 +0000
Subject: [PATCH 4/6] joint_attention_kwargs or {} attention.py

---
 src/diffusers/models/attention.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index e8a47fa8226a..41c3f50c92d4 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -192,8 +192,9 @@ def forward(
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
-        joint_attention_kwargs: Dict[str, Any] = {},
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
+        joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                 hidden_states, emb=temb
             )

From c771d7ec8064639db794c74d6c6f4c1d61aaa305 Mon Sep 17 00:00:00 2001
From: hlky
Date: Mon, 9 Dec 2024 11:18:43 +0000
Subject: [PATCH 5/6] joint_attention_kwargs or {} transformer_sd3.py

---
 src/diffusers/models/transformers/transformer_sd3.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index 19c690643881..947ef675cece 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -337,7 +337,7 @@ def forward(
         pooled_projections: torch.FloatTensor = None,
         timestep: torch.LongTensor = None,
         block_controlnet_hidden_states: List = None,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = {},
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         skip_layers: Optional[List[int]] = None,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
@@ -369,6 +369,7 @@ def forward(
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
         """
+        joint_attention_kwargs = joint_attention_kwargs or {}
         if joint_attention_kwargs is not None:
             joint_attention_kwargs = joint_attention_kwargs.copy()
             lora_scale = joint_attention_kwargs.pop("scale", 1.0)

From 7ede8f297221a20245e9de88482570958d125f6f Mon Sep 17 00:00:00 2001
From: hlky
Date: Wed, 11 Dec 2024 08:57:46 +0000
Subject: [PATCH 6/6] Update src/diffusers/models/transformers/transformer_sd3.py

Co-authored-by: YiYi Xu
---
 src/diffusers/models/transformers/transformer_sd3.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index 947ef675cece..79c4069e9a37 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -369,7 +369,6 @@ def forward(
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
         """
-        joint_attention_kwargs = joint_attention_kwargs or {}
         if joint_attention_kwargs is not None:
             joint_attention_kwargs = joint_attention_kwargs.copy()
             lora_scale = joint_attention_kwargs.pop("scale", 1.0)