
Memory-efficient attention (without xformers) #1892

Open
Birch-san opened this issue Jan 3, 2023 · 30 comments

@Birch-san
Contributor

I implemented sub-quadratic attention (as described in https://arxiv.org/abs/2112.05682v2):
https://twitter.com/Birchlabs/status/1607503573906063362
Birch-san#1
Birch-san/diffusers-play@a573e3d

is this worth upstreaming? it enables creation of images larger than can be achieved with attention slicing.

@patrickvonplaten
Contributor

Hey @Birch-san,

Thanks a lot for the issue! @patil-suraj what do you think?

@patil-suraj
Contributor

Very cool @Birch-san, is this more efficient than xformers? Also, the xformers installation situation is getting better now (cf. https://pypi.org/project/xformers/#history), so I'm not sure we need another efficient attention for PT. This could be a good addition to flax.

@Birch-san
Contributor Author

Birch-san commented Jan 5, 2023

this implements the same paper as xformers memory efficient attention.
it's unlikely to be more efficient than xformers, since they have the advantage of custom CUDA kernels.

xformers is CUDA-only, I presume? no support for MPS or ROCm or Metal or CPU backends?

there are Mac users trying to run stable-diffusion on Mac Minis with 8GB of unified memory. IIRC they couldn't even fit 512x512 images. sliced attention helped with that, but this goes further: you can chunk up attention far finer, arbitrarily so.

if my calculations are correct: a 2048x2048 image's self-attention can require 80GB VRAM ordinarily. setting sliced attention to its most aggressive setting can get this down to 40GB slices (MPS backend refuses to allocate this). but the chunked attention implementation I've provided here can get it down to anything you like, e.g. 80MB chunks.
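For illustration, a minimal sketch of the query-chunking idea (assuming [batch*heads, seq, head_dim] tensors; the linked implementation follows the paper further and also chunks keys/values with a running, renormalized softmax):

```python
import torch

def query_chunked_attention(q, k, v, chunk_size=1024):
    """Sketch only: bounds the peak score-matrix size to [batch*heads, chunk_size, k_len]."""
    scale = q.shape[-1] ** -0.5
    out = torch.empty_like(q)
    for i in range(0, q.shape[1], chunk_size):
        q_chunk = q[:, i:i + chunk_size]
        scores = torch.baddbmm(
            torch.empty(q_chunk.shape[0], q_chunk.shape[1], k.shape[1],
                        dtype=q.dtype, device=q.device),
            q_chunk, k.transpose(1, 2),
            beta=0, alpha=scale,  # beta=0: the empty tensor's contents are ignored
        )
        out[:, i:i + chunk_size] = torch.bmm(scores.softmax(dim=-1), v)
    return out
```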

@Lime-Cakes
Contributor

This memory-efficient attention might not be faster than the xformers implementation, but since it doesn't rely on a custom CUDA kernel, it means better support for non-CUDA devices, which would be good for Mac and other non-CUDA accelerators on PyTorch.

@patrickvonplaten
Contributor

@pcuenca could we maybe run some tests for MPS for this?

xformers will soon be natively supported in PyTorch, so I'm wondering how important this is for PyTorch-only use. I definitely see an important use case for MPS though.

Also with the new attention processor system it should be relatively easy to add either way.
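For illustration only, a rough sketch of what such an opt-in processor could look like (the class name is hypothetical, the helper methods mirror diffusers' built-in processors, and query_chunked_attention refers to the sketch earlier in this thread, not a diffusers API):

```python
class SubQuadraticAttnProcessor:
    """Hypothetical opt-in processor; masking omitted for brevity."""
    def __init__(self, query_chunk_size: int = 1024):
        self.query_chunk_size = query_chunk_size

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        key = attn.head_to_batch_dim(attn.to_k(context))
        value = attn.head_to_batch_dim(attn.to_v(context))
        hidden_states = query_chunked_attention(query, key, value, self.query_chunk_size)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)  # output projection
        return attn.to_out[1](hidden_states)           # dropout

# opt in per-model, e.g.: pipe.unet.set_attn_processor(SubQuadraticAttnProcessor())
```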

@patil-suraj
Contributor

Good point regarding MPS @Birch-san ! In that case, it would be cool to have this. cc @pcuenca

@keturn
Contributor

keturn commented Feb 3, 2023

There's another implementation here: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/ldm/modules/sub_quadratic_attention.py#L153

comfy reports:

I tweaked the sub-quadratic optimization when I implemented it in my own UI and it gave me a nice speed boost.
but I had to tweak it first cause by default it didn't give better performance than the split optimization on my 6800XT

@pcuenca
Member

pcuenca commented Feb 6, 2023

Hi @Birch-san, sorry for being so slow to react here! Is there any chance you could submit a PR that applies these optimizations to mps? That way it would be easier for us to test and discuss. Perhaps you can use the new attention processor mechanism so people can opt in (or make it default maybe) if they are running on mps. If you can't then I'll test your branch :)

Another question I have is: will this become obsolete with the upcoming memory-efficient attention integrated in PyTorch 2.x? Or does that not work for mps at all?

@github-actions
Contributor

github-actions bot commented Mar 2, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Mar 2, 2023
@keturn
Contributor

keturn commented Mar 2, 2023

Still interested, Stalebot!

Related:

@Birch-san
Contributor Author

yup, planning to submit PRs for memory-efficient attention and cross-attention masks soon.

@patrickvonplaten
Contributor

PyTorch 2.0 will provide memory efficient attention out of the box - is this PR still relevant then?

@Birch-san
Contributor Author

will PyTorch 2.0 provide memory-efficient attention for Mac?

@Birch-san
Contributor Author

Birch-san commented Mar 6, 2023

@patrickvonplaten @pcuenca

https://pytorch.org/docs/2.0/generated/torch.nn.functional.scaled_dot_product_attention.html

the docs for memory-efficient attention link to xformers, which to my knowledge does not support MPS (it's focused on CUDA and triton).

the PyTorch switch for activating memory-efficient attention is torch.backends.cuda.enable_mem_efficient_sdp(), which again is CUDA-specific.

so I think there's still a case for implementing this: for Mac, iOS, ROCm, CPU. it also means that if you trace the torchscript operations: you'd get a memory-efficient torchscript model, which you could convert to CoreML†.

if you wanted the CoreML model to be optimal for Neural Engine: you'd need to reshape the algorithm a bit; Neural Engine prefers B,C,1,S tensors over B,S,C tensors.
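For reference, the op and the switch being discussed (a sketch assuming PyTorch 2.0; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 8, 4096, 64) for _ in range(3))  # [batch, heads, seq, head_dim]

# PyTorch picks the backend; on MPS/CPU there is no memory-efficient kernel to pick,
# so this falls back to the generic math implementation.
out = F.scaled_dot_product_attention(q, k, v)

# the CUDA-specific switch mentioned above
torch.backends.cuda.enable_mem_efficient_sdp(True)
```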

@pcuenca pcuenca removed the stale Issues that haven't received updates label Mar 7, 2023
@pcuenca
Member

pcuenca commented Mar 7, 2023

@Birch-san Thanks for your interesting thoughts!

The documentation you linked to above also mentions a C++ implementation that I understand should be used when not targeting the cuda backend. I haven't had the chance to test it yet, but I'm planning to and will report back here.

Re: Core ML, if I'm not mistaken Apple's conversion code does the shape transformation when converting. We could explore whether it makes sense to support it in the diffusers codebase.

@Birch-san
Contributor Author

Birch-san commented Mar 7, 2023

@pcuenca

Re: Core ML, if I'm not mistaken Apple's conversion code does the shape transformation when converting.

regardless of whether you target GPU or ANE: Apple traces their own bespoke UNet2DConditionModel:
https://github.com/apple/ml-stable-diffusion/blob/2c4e9de73c9e723de264356f9563706ea9104212/python_coreml_stable_diffusion/torch2coreml.py#L690-L723

Apple's bespoke UNet2DConditionModel changes every Linear layer into a Conv2D and modifies every LayerNorm, to keep tensors in [B,C,1,S] format. this is regardless of which ATTENTION_IMPLEMENTATION_IN_EFFECT is selected, and regardless of whether you target GPU or ANE.

their model offers an ATTENTION_IMPLEMENTATION_IN_EFFECT parameter, which just toggles whether sliced attention is used (to save memory — at the expense of speed — by serializing attention matmuls on batch dimension). they recommend this mode for memory-constrained devices.

my prediction: if you just want to target GPU, the default diffusers Unet would be faster (because [B,S,C] tensors are preferred by GPU -- they have a batch dimension with which you can do baddbmm() and bmm() matmuls).

also: I notice they don't fuse the * scale into the matmul. maybe CoreML is smart enough to do that for them, but if it's not: they're leaving an 18% speed boost on the table.

IIRC their coremltools implements baddbmm() support by unfusing the operation back into a matmul and a multiply, so I'm not sure whether the answer is as simple as "replace einsum() * scale with baddbmm()".

if CoreML fundamentally lacks support for baddbmm(), or lacks support for automatically fusing multiplies into matmuls: there's a cheeky backdoor you can use to get a fused multiply: by burning the * scale into the projection weights of the model.
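A minimal sketch of that backdoor, assuming a standard to_q Linear projection (illustration only, not Apple's or diffusers' code):

```python
import torch

def burn_scale_into_q_proj(to_q: torch.nn.Linear, head_dim: int) -> None:
    # q' = scale * (W x + b) == (scale * W) x + (scale * b),
    # so the 1/sqrt(head_dim) multiply disappears from the attention matmul.
    scale = head_dim ** -0.5
    with torch.no_grad():
        to_q.weight.mul_(scale)
        if to_q.bias is not None:
            to_q.bias.mul_(scale)
```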

We could explore whether it makes sense to support it in the diffusers codebase.

supporting two different possible tensor shapes throughout the entire Unet algorithm is hard to do in a clean way.

my recommendation is to un-break diffusers' support for PyTorch's _register_load_state_dict_pre_hook() idiom. this was the low-touch technique Apple originally used to modify BERT in their ANE-optimized transformers whitepaper.
#1880 (comment)
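For context, a simplified sketch of that idiom (loosely modeled on Apple's ane_transformers approach; the class and hook names here are made up): a 1x1-conv module that reshapes stock nn.Linear checkpoint weights as they are loaded.

```python
import torch

class ANELinear(torch.nn.Conv2d):
    """1x1 conv that can load weights saved for an equivalent nn.Linear."""
    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features, kernel_size=1)
        self._register_load_state_dict_pre_hook(self._linear_to_conv2d)

    @staticmethod
    def _linear_to_conv2d(state_dict, prefix, *args):
        key = prefix + "weight"
        if key in state_dict and state_dict[key].dim() == 2:
            state_dict[key] = state_dict[key][:, :, None, None]  # [out, in] -> [out, in, 1, 1]
```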

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Mar 31, 2023
@github-actions github-actions bot closed this as completed Apr 9, 2023
@Birch-san
Contributor Author

Birch-san commented Apr 10, 2023

Reopening; we still have a need for pure-PyTorch memory-efficient attention on systems such as Mac. I'm a bit tied up trying to get cross-attention bias (#2634) over the line, but still hoping to get round to upstreaming my memory-efficient attention implementation in the coming weeks.

@pcuenca pcuenca reopened this Apr 10, 2023
@pcuenca
Member

pcuenca commented Apr 10, 2023

@Birch-san sounds great!

@pcuenca pcuenca removed the stale Issues that haven't received updates label Apr 12, 2023
@williamberman
Contributor

JW, how hard would it be for us to instead implement this as a native MPS kernel, either in PyTorch or in whatever CUDA-toolkit-equivalent library Apple maintains, if one exists (and if so, is it open source)? That would be the ideal way to support this, instead of merging into diffusers and later going through a deprecation cycle once there's official support.

@Birch-san
Contributor Author

hmm I'm not aware of any more Mac-specific library implementation of this available.

as for whether to implement it as a native MPS kernel… well, what it does is relatively simple (it's expressible in pure PyTorch, so one could look up what underlying MPS operations that compiles down to and cram it into a kernel somehow).

MPS/Metal programming isn't my wheelhouse though, and I definitely don't see myself getting the time to learn how to write, then write, one of those.

as for the broader Mac story (CoreML export): if you wanted an access pattern optimized for Neural Engine: you'd probably want to tweak it to use [batch, channels, 1, tokens] tensors, with tokens being contiguous and aligned to 64 bytes.
yet, whilst that's good for Neural Engine: it's probably not optimal for GPU. so you'd kinda want both options.

anyway, regardless of what the answer is for Mac: there's other backends that would benefit from a pure-pytorch implementation of memory-efficient attention. like ROCm, and CPU. and maybe backend-agnostic export formats like TorchScript or ONNX? for example for targeting WebGPU.

even CUDA users could still find this useful, despite already having two IO-aware memory-efficient kernels for attention. because neither of those bespoke kernels support forward-mode autograd. I understand from EleutherAI Discord that @crowsonkb saw use-cases for an attention implementation that supported forward-mode autograd, and right now the only way to do that is probably pure-PyTorch. so doing a memory-efficient version of that in pure-pytorch would still be useful.
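To illustrate that last point, a small sketch (assuming PyTorch 2.0's torch.func and a plain unfused attention): forward-mode jvp composes with pure-PyTorch ops, which the fused CUDA attention kernels don't currently support.

```python
import torch
from torch.func import jvp

def attention(q, k, v):
    scale = q.shape[-1] ** -0.5
    return torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1) @ v

q, k, v = (torch.randn(8, 128, 64) for _ in range(3))
tangent = torch.randn_like(q)

# forward-mode derivative of the attention output w.r.t. q, in a single pass
out, out_tangent = jvp(lambda q_: attention(q_, k, v), (q,), (tangent,))
```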

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label May 23, 2023
@pcuenca pcuenca added wip and removed stale Issues that haven't received updates labels May 23, 2023
@TeutonJon78

This and/or other non-CUDA optimizations would also be helpful for DirectML users since we need every bit of extra VRAM savings.

@Beinsezii
Contributor

Beinsezii commented Feb 28, 2024

For non-Nvidia folk I rebased https://github.com/Birch-san/diffusers/tree/subquad_attn to the diffusers master branch over at my own fork https://github.com/Beinsezii/diffusers/tree/attn_subquad

On my 7900 XTX having the query chunk at 2^12 and the kv chunk at 2^15 I can process 1024 images slightly faster than any other attention method currently working on AMD. Additionally with the memory savings I successfully ran an > 8Mpx image using tiled VAE + subquad, where previously I don't even think I could reach half of that.

The ported code is so old it doesn't have masking or a fix for the upcast softmax so it's not a magic bullet. Notably it doesn't work with the 1.5 VAE. It's probably possible to yoink a more updated subquad impl from one of the other diffusion UIs as a hacky fix for those issues, but that's a task for another day.

@Beinsezii
Contributor

Beinsezii commented Feb 28, 2024

I updated my fork to a newer subquad impl with fixed upcast and masking, using a modified XFormers attn class since to my untrained eye they seem to function the same. Seems to work on all models now.

The updated attention function sourced from
https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/ldm/modules/sub_quadratic_attention.py
The repo is licensed under GPL 3 while the file header says it's MIT, so I'm not sure which it would be; ergo I don't think I should open a PR.

@bghira
Contributor

bghira commented Mar 25, 2024

that file is MIT-licensed code, as licenses can be applied per-file in a GPL-3 project.

@tzayuan

tzayuan commented Apr 1, 2024

Hi, @Birch-san, @Beinsezii

I would like to ask: if my model was trained with the xformers-based attention operator, is it feasible to switch it to a torch-based attention implementation while still using the previously trained weights, i.e. avoiding importing xformers during inference? Thanks.

@bghira
Contributor

bghira commented Apr 1, 2024

@Birch-san there is metal-flash-attention but i don't know how to use it. how would we integrate that here?

@Birch-san
Contributor Author

@tzayuan yes, a model trained with xformers-based attention can be modified to use torch sdp attention. no need to import xformers, no need to retrain. one thing to be aware of is that:
xformers expects a [batch, seq, heads, head_channels] permutation, whereas
torch sdp expects […batch, seq, head_channels], with heads folded into the leading batch dimensions (i.e. [batch, heads, seq, head_channels]).
in other words: xformers will do the permute for you, whereas torch expects you to permute the head dimension to follow the batch dimension. this will make it a little more fiddly to switch to the torch operation.
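A small sketch of that layout difference (shapes assumed for illustration):

```python
import torch
import torch.nn.functional as F
# import xformers.ops as xops  # only if switching *from* xformers

b, s, h, d = 2, 4096, 8, 64
q, k, v = (torch.randn(b, s, h, d) for _ in range(3))  # xformers layout: [batch, seq, heads, head_dim]

# xformers: xops.memory_efficient_attention(q, k, v) accepts this layout directly.
# torch sdp wants heads to follow batch: [batch, heads, seq, head_dim]
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)  # back to [batch, seq, heads, head_dim]
```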

@bghira not simple at all. I'm not actually sure whether the PyTorch MPS backend can invoke custom Metal kernels. but assuming that's possible, you'd need to contribute (in the PyTorch source code, probably in C++ and Python) an MPS backend for (the forward pass of) scaled_dot_product_attention. you'd bring Philip's sources into the PyTorch source tree (tell the build system to compile and link Mac targets against them), and introduce some kind of adapter from the PyTorch domain (i.e. tensors / MPS memory) to the Metal domain (however that works). then write tests, and get it reviewed & merged by the PyTorch team.

@tzayuan

tzayuan commented Apr 2, 2024

Hi, @Birch-san
Thanks for your suggestion, I have finished it.
