Higher VRAM usage with PyTorch2 without xformers under certain situations #3441

Closed

wfng92 opened this issue May 16, 2023 · 3 comments

Labels: bug (Something isn't working)

Comments

@wfng92 (Contributor) commented May 16, 2023

Describe the bug

Edited to reflect the actual issue

In the latest development version, PR #3365 introduces confusion by leading users to believe that the PT2 variants of the attention processors are fully supported and that xformers is no longer needed. This results in higher VRAM usage in certain situations (e.g. using LoRA or custom diffusion without xformers).

If the environment has both PyTorch 2.0 and xformers installed, diffusers will

  • raise a warning
  • default to PyTorch's native efficient flash attention

diffusers currently supports the following PT 2.0 variants of attention processors:

  • AttnProcessor => AttnProcessor2_0
  • AttnAddedKVProcessor => AttnAddedKVProcessor2_0

The following processors do not yet have PT 2.0 variants:

  • SlicedAttnProcessor
  • SlicedAttnAddedKVProcessor
  • LoRAAttnProcessor
  • CustomDiffusionAttnProcessor

It would be great if users could still use xformers when calling pipe.enable_xformers_memory_efficient_attention().
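For reference, one way to see which processors actually end up on the UNet after requesting xformers is to inspect pipe.unet.attn_processors. This is a minimal sketch, assuming a standard Stable Diffusion 1.5 checkpoint; the exact class names reported depend on the installed diffusers version:

import torch
from collections import Counter
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

# Request xformers, then check which processor classes were actually assigned.
pipe.enable_xformers_memory_efficient_attention()

# With PyTorch 2.0 installed, this may still report AttnProcessor2_0 rather than
# XFormersAttnProcessor, which is the silent fallback described above.
print(Counter(type(proc).__name__ for proc in pipe.unet.attn_processors.values()))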

Reproduction

import torch
from diffusers import (
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
)

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# Load LoRA attention processors, then request xformers memory-efficient attention
pipe.unet.load_attn_procs("pytorch_lora_weights.bin")
pipe.enable_xformers_memory_efficient_attention()

prompt = "a photo of a dog"
image = pipe(prompt=prompt, cross_attention_kwargs={"scale": 1.0}).images
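As a side note, the "higher VRAM usage" can be quantified by comparing the peak CUDA memory of this run against a run without the LoRA weights (or with xformers actually active). A minimal sketch that continues from the reproduction code above; the absolute numbers will vary with GPU, resolution, and batch size:

# Reset the CUDA peak-memory counter, run one generation, and report the peak.
# Repeat for each configuration being compared (with/without LoRA, with/without xformers).
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

image = pipe(prompt=prompt, cross_attention_kwargs={"scale": 1.0}).images

print(f"peak VRAM: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")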

Logs

"You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
"We will default to PyTorch's native efficient flash attention implementation provided by PyTorch 2.0.

System Info

  • diffusers version: 0.17.0.dev0
  • Platform: Windows-10-10.0.19045-SP0
  • Python version: 3.10.11
  • PyTorch version (GPU?): 2.0.0+cu118 (True)
  • Huggingface_hub version: 0.13.4
  • Transformers version: 4.28.1
  • Accelerate version: 0.18.0
  • xFormers version: 0.0.19
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No
wfng92 added the bug (Something isn't working) label on May 16, 2023
@patrickvonplaten (Contributor) commented
Thanks for the issue @wfng92, I agree we should maybe not force-disable xformers - @sayakpaul can we maybe revert/change your PR here?

sayakpaul self-assigned this on May 17, 2023
@sayakpaul (Member) commented
@patrickvonplaten on it. I will drop a follow-up PR to clear this regression.

@sayakpaul (Member) commented
@wfng92 opened #3457. Let's continue the discussion there :)

wfng92 closed this as completed on May 17, 2023
wfng92 changed the title from "Higher VRAM usage with PyTorch2 under certain situations" to "Higher VRAM usage with PyTorch2 without xformers under certain situations" on May 17, 2023