Skip to content

Support pass kwargs to sd3 custom attention processor #9818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

Matrix53
Copy link
Contributor

@Matrix53 Matrix53 commented Oct 31, 2024

What does this PR do?

Fixes #8855 as #9517

As mentioned in #8855 and #9516, pass kwargs to custom attention processor is preferable when implementing layout control, attention replacing or introducing side effect to attention computation. Now this feature is supported by sd, sdxl, FLUX series, but not sd3. This PR fix the problem, and may support future training-free experiments in sd3 series models.

example code:

class MyAttentionProcessor(JointAttnProcessor2_0):
    def __init__(self, **custom_args):
        super().__init__()
        # some initialization
    
    def __call__(self, ..., added_argu=None) -> torch.FloatTensor:
        # custom attention forward pass

my_processor = MyAttentionProcessor(**custom_args)
pipe.transformer.set_attn_processor(my_processor) # sd3 pipeline
joint_attention_kwargs = {"added_argu": 100} # prepare passing added_argu to attention processor
pipe.transformer(..., joint_attention_kwargs=joint_attention_kwargs) # custom transformer forward pass

The documentation already describes usage of the joint_attention_kwargs:
image

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Hello @yiyixuxu @sayakpaul, I just want to do the same to sd3 as to FLUX in #9517

Copy link
Member

@a-r-r-o-w a-r-r-o-w left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR looks good for the specific changes required to support kwargs, but all the other changes are unrelated. Please revert any changes not related supporting kwargs such as multi-line import, multi-line assigments, etc. These were not required before, nor needed now due to 120 token column limit. If these were caused by your text editor automatically, please configure it to use the same settings as pyproject.yaml for this specific repo

@Matrix53 Matrix53 force-pushed the support-pass-kwargs-to-sd3-custom-attn-processor branch from f2ae8a1 to e3bed52 Compare November 1, 2024 01:49
@Matrix53
Copy link
Contributor Author

Matrix53 commented Nov 1, 2024

Thanks @a-r-r-o-w 😊, I've configured my editor and reverted all unrelated changes. Now this PR only contains necessary changes.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Matrix53 Matrix53 requested a review from hlky December 9, 2024 10:48
Copy link
Contributor

@hlky hlky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Edit: let's add joint_attention_kwargs = joint_attention_kwargs or {} instead like https://github.com/huggingface/diffusers/pull/9517/files

Copy link
Contributor

@hlky hlky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Matrix53!

@hlky
Copy link
Contributor

hlky commented Dec 9, 2024

cc @sayakpaul @yiyixuxu WDYT about making default joint_attention_kwargs as {}? We end up doing joint_attention_kwargs or {} anyway, there are several checks for None which are sort of unnecessary.

@sayakpaul
Copy link
Member

Not opposed to the idea.

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

@yiyixuxu yiyixuxu merged commit 8eb73c8 into huggingface:main Dec 18, 2024
15 checks passed
Foundsheep pushed a commit to Foundsheep/diffusers that referenced this pull request Dec 23, 2024
* Support pass kwargs to sd3 custom attention processor


---------

Co-authored-by: hlky <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* Support pass kwargs to sd3 custom attention processor


---------

Co-authored-by: hlky <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

"joint_attention_kwargs" don't pass the parameters to AttentionProcessor
6 participants