Refactor attention.py #1880

Closed
patrickvonplaten opened this issue Jan 1, 2023 · 16 comments
patrickvonplaten commented Jan 1, 2023

attention.py currently has two concurrent attention implementations that essentially do the same thing:

Both

class CrossAttention(nn.Module):
and
class AttentionBlock(nn.Module):
are already used for "simple" attention - e.g. the former for Stable Diffusion and the latter for the simple DDPM UNet.

We should start deprecating

class AttentionBlock(nn.Module):
very soon as it's not viable to keep two attention mechanisms.

Deprecating this class won't be easy, as it essentially means we have to force people to re-upload their weights. Every model checkpoint that makes use of

class AttentionBlock(nn.Module):

will eventually have to have its weights re-uploaded to stay compatible.

I would propose to do this in the following way:

@williamberman williamberman self-assigned this Jan 1, 2023
Birch-san commented Jan 3, 2023

@patrickvonplaten Rather than having everybody save & re-upload their weights: can diffusers intercept the weights during model load and map them to different parameter names?

Apple uses PyTorch's _register_load_state_dict_pre_hook() idiom to intercept the weights while the state dict is being loaded, transform them, and redirect them into different Parameters:
https://github.com/apple/ml-ane-transformers/blob/da64000fa56cc85b0859bc17cb16a3d753b8304a/ane_transformers/huggingface/distilbert.py#L241
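For reference, a minimal sketch of that idiom with toy parameter names (an old "query" key remapped to a new "to_q" key; the real mapping between AttentionBlock and CrossAttention involves more parameters than this):

import torch.nn as nn

class RefactoredAttention(nn.Module):
    # Toy sketch: the module stores weights under a new name (to_q) but can still
    # load checkpoints saved under an old name (query) by rewriting the state dict
    # before load_state_dict copies tensors into parameters.
    def __init__(self, dim=8):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self._register_load_state_dict_pre_hook(self._remap_old_keys)

    def _remap_old_keys(self, state_dict, prefix, *args):
        for suffix in ("weight", "bias"):
            old_key = f"{prefix}query.{suffix}"
            if old_key in state_dict:
                state_dict[f"{prefix}to_q.{suffix}"] = state_dict.pop(old_key)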

However, something about HF's model-loading technique breaks this idiom: my model-loading hooks never get invoked. They work in a CompVis repository, but not inside HF diffusers code. I think something about using importlib to load a .bin skips them. It'd be really good if you could fix that; it's the number one thing that made it difficult for me to optimize the diffusers UNet for the Neural Engine.

In the end, this is the technique I had to resort to in order to replace every AttentionBlock with CrossAttention (after model loading):
Birch-san/diffusers-play@bf9b13e
You may find it a useful reference for how to map between the two.

@williamberman

@Birch-san Thank you for the added context, super helpful! I don't have much to add right now. When I start working on the refactor, I'll think about it more and we can discuss :)

@Lime-Cakes

It seems the two classes might have some slight differences. I noticed group_norm is missing from a few of the processor implementations for CrossAttention, which can be constructed with or without group_norm:

CrossAttnProcessor doesn't use group_norm, while SlicedAttnAddedKVProcessor and CrossAttnAddedKVProcessor use group_norm (without checking whether it is actually None, which CrossAttention allows).

Whereas AttentionBlock in attention.py always uses group_norm.
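For illustration, a minimal sketch of the defensive check implied here (the channels-first transpose is an assumption based on how the added-KV processors apply group_norm; not the actual diffusers code):

# Hypothetical helper: apply the optional group norm only when it exists,
# mirroring how CrossAttention can be built with or without group_norm.
def maybe_group_norm(attn, hidden_states):
    if attn.group_norm is not None:
        # group_norm normalizes over channels, so transpose
        # (batch, seq_len, channels) -> (batch, channels, seq_len) and back.
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
    return hidden_states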

@williamberman

tl;dr

  1. Does the API design of the attention processors prohibit anything from living in CrossAttention.__call__, or should the entire method body exist in the processor so that all of the attention mechanism is hackable? An example is the residual connection.
  2. What ad-hoc configuration should we allow for attention processors? What about when it results in diverging defaults, i.e. making residual connections configurable now would give different processors different defaults?
  3. While porting AttentionBlock to CrossAttention, should we make CrossAttnProcessor configurable or add a new processor that is mostly the same as CrossAttnProcessor plus a residual connection? (There are potentially other changes that might be needed; I haven't looked through all of it yet.)

longer message:

QQ re API design of the attention processors: if we were to configure whether or not there is a residual connection, would this ideally occur in the processor class or in CrossAttention? Currently we just pass everything from __call__ into the processor. Some processors have a residual connection, some do not. Because the attention block has the same input and output dims, the residual connection would be the same regardless of what happens in the processor, so it could be done within CrossAttention.__call__.

However, doing the residual connection within CrossAttention.__call__ means the attention application is not entirely hackable.

The context is that AttentionBlock has a residual connection and CrossAttnProcessor does not, so I'm leaning towards adding a flag to CrossAttnProcessor's constructor for it to perform a residual connection. However, if we then made the residual connection configurable in the other processors that already have one, they would end up with different default values.

Alternatively, we could add a separate attention processor for the soon-to-be-deprecated AttentionBlock that would be mostly copy-pasted from the existing CrossAttnProcessor.
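For illustration, a minimal sketch of the "configurable residual connection" option, assuming the attn module exposes the usual processor helpers (to_q/to_k/to_v, head_to_batch_dim, get_attention_scores, to_out); the class and flag names are hypothetical, not the actual diffusers API:

import torch

class AttnProcessorWithResidual:
    # Hypothetical sketch: same flow as a plain attention processor, plus an
    # opt-in residual connection. Defaults to off to match CrossAttnProcessor's
    # current behaviour.
    def __init__(self, residual_connection: bool = False):
        self.residual_connection = residual_connection

    def __call__(self, attn, hidden_states, encoder_hidden_states=None):
        residual = hidden_states
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        # Project and split heads using the attention module's helpers.
        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        key = attn.head_to_batch_dim(attn.to_k(context))
        value = attn.head_to_batch_dim(attn.to_v(context))

        attention_probs = attn.get_attention_scores(query, key)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        hidden_states = attn.to_out[0](hidden_states)  # output projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout

        if self.residual_connection:
            hidden_states = hidden_states + residual
        return hidden_states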

@williamberman

Follow up: residual connections would stay in the processor regardless, because it isn't guaranteed that the residual connection is the last step in the method; e.g. in AttentionBlock, we rescale the output after the residual connection is applied.

IMO, this means that regardless of commonalities, the entirety of the attention application should occur in the processor. Anything we assume to be common to all processors is a potential point of breakage and might require ugly hacks to make future attention processors work.

However, the other questions around configuration and defaults still stand.

@williamberman

tl;dr from offline convo:

The existing CrossAttnProcessor provides attention over inputs of size (batch_size, seq_len, hidden_size), while the AttentionBlock we're deprecating provides attention over spatial inputs of size (batch_size, channels, height, width), so we'll make a separate class called SpatialAttnProcessor.

For now, we'll only add self attention to the new attention processor and we can add in cross attention later. Note that we'll also change the name of CrossAttnProcessor to just AttnProcessor so the standard naming will be consistent regardless of the type of attention applied (these are internal/private classes, so changing the names should be acceptable).

We did not discuss what will happen in the future if we have to add configuration to different attention processors and that results in different default configs (i.e. the residual connection example earlier). Let's assume it's OK to leave this undiscussed for now, especially as these are private classes and we'll have more flexibility if we need to change them.
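For illustration, a rough sketch of what such a SpatialAttnProcessor could look like (self-attention only). It reuses the same hypothetical attn helpers as the sketch above; the reshaping between spatial and sequence layouts is the key point, the rest is assumption:

import torch

class SpatialAttnProcessor:
    # Hypothetical sketch: self-attention over spatial inputs
    # of shape (batch_size, channels, height, width).
    def __call__(self, attn, hidden_states):
        residual = hidden_states
        batch, channels, height, width = hidden_states.shape

        # Flatten the spatial dims into a sequence: (b, c, h, w) -> (b, h*w, c)
        hidden_states = hidden_states.view(batch, channels, height * width).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        key = attn.head_to_batch_dim(attn.to_k(hidden_states))
        value = attn.head_to_batch_dim(attn.to_v(hidden_states))

        attention_probs = attn.get_attention_scores(query, key)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))
        hidden_states = attn.to_out[0](hidden_states)

        # Restore the spatial layout and keep AttentionBlock's residual connection.
        # (AttentionBlock also rescales the output after the residual; omitted here.)
        hidden_states = hidden_states.transpose(1, 2).view(batch, channels, height, width)
        return hidden_states + residual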

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale label Feb 21, 2023
@github-actions github-actions bot closed this as completed Mar 1, 2023
@patrickvonplaten patrickvonplaten added the wip label and removed the stale label Mar 2, 2023
@patrickvonplaten

This is still very much relevant cc @williamberman

@patrickvonplaten

We're getting too many issues / PRs from confused users. Let's try to make this high prio @williamberman

@patrickvonplaten

To begin with, let's do the following:

1.) Rename all processors that are called CrossAttention to just Attention.
2.) Rename the file cross_attention.py to attention_processor.py.

Note: We need to keep full backwards compatibility: we import all classes from attention_processor.py into cross_attention.py and raise a deprecation warning whenever someone imports from cross_attention.py. If classes are renamed in attention_processor.py, we should import them as follows:

from .attention_processor import AttentionProcessor as CrossAttentionProcessor
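For illustration, a minimal sketch of what the backwards-compatible cross_attention.py shim could look like (the module-level warnings.warn shown here is just one possible mechanism; the deprecation utility diffusers actually uses may differ):

# cross_attention.py -- backwards-compatibility shim (illustrative sketch)
import warnings

# Re-export the renamed classes under their old names so existing imports keep working.
from .attention_processor import AttentionProcessor as CrossAttentionProcessor  # noqa: F401

warnings.warn(
    "Importing from cross_attention.py is deprecated; "
    "import from attention_processor.py instead.",
    FutureWarning,
)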

@patrickvonplaten

Once that's done, let's continue by fully removing the old AttentionBlock class.

Lime-Cakes commented Mar 13, 2023

Would any change affect old models (e.g., renaming state dict keys)? It seems like changing/removing AttentionBlock wouldn't change the PyTorch state dict, as key names aren't based on class names. I'm unsure whether other formats such as Flax and safetensors would require changes, though.
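For illustration, a quick check of that point with toy modules (not diffusers code): PyTorch state dict keys come from attribute paths, so renaming or swapping the class leaves them unchanged as long as the attribute names match.

import torch.nn as nn

class OldBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class RenamedBlock(nn.Module):  # different class name, same attribute names
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

print(list(OldBlock().state_dict().keys()))      # ['proj.weight', 'proj.bias']
print(list(RenamedBlock().state_dict().keys()))  # same keys: ['proj.weight', 'proj.bias']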

@patrickvonplaten

Start refactor here: #2691 (comment)

@haofanwang

@patrickvonplaten Is the refactoring done? I'm using a codebase built on diffusers 0.11; if it's done, I can start updating my code.

@patrickvonplaten

#2697 will be merged very soon!

@williamberman

Done here: #3387
