Implement .swap() against diffusers 0.12 #2385


Merged
Changes from all commits
27 commits
0c2a511
wip SwapCrossAttnProcessor
damian0815 Jan 21, 2023
bffe199
SwapCrossAttnProcessor working - tested on mac CPU (MPS doesn't work)
damian0815 Jan 21, 2023
313b206
squash float16/float32 mismatch on linux
damian0815 Jan 22, 2023
c0610f7
pass missing value
damian0815 Jan 22, 2023
63c6019
sliced attention processor wip (untested)
damian0815 Jan 24, 2023
3c53b46
Merge branch 'main' into diffusers_cross_attention_control_reimplemen…
keturn Jan 25, 2023
a4aea15
more wip sliced attention (.swap doesn't work yet)
damian0815 Jan 25, 2023
c52dd7e
Merge branch 'diffusers_cross_attention_control_reimplementation' of …
damian0815 Jan 25, 2023
1f5ad1b
sliced swap working
damian0815 Jan 25, 2023
34a3f4a
cleanup
damian0815 Jan 25, 2023
41aed57
wip tracking down MPS slicing support
damian0815 Jan 25, 2023
95d147c
MPS support: negatory
damian0815 Jan 25, 2023
93a2444
Merge remote-tracking branch 'upstream/main' into diffusers_cross_att…
damian0815 Jan 25, 2023
5e7ed96
wip updating docs
damian0815 Jan 25, 2023
8ed8bf5
use 'auto' slice size
damian0815 Jan 26, 2023
7297526
trying out JPPhoto's patch on vast.ai
damian0815 Jan 26, 2023
fb312f9
use the correct value - whoops
damian0815 Jan 26, 2023
c381788
don't restore None
damian0815 Jan 26, 2023
e090c0d
try without setting every time
damian0815 Jan 26, 2023
5ce62e0
Merge branch 'main' into diffusers_cross_attention_control_reimplemen…
JPPhoto Jan 29, 2023
27ee939
with diffusers cac, always run the original prompt on the first step
damian0815 Jan 30, 2023
c5c160a
Merge branch 'diffusers_cross_attention_control_reimplementation' of …
damian0815 Jan 30, 2023
478c379
for cac make t_start=0.1 the default
damian0815 Jan 30, 2023
17d73d0
Revert "with diffusers cac, always run the original prompt on the fir…
damian0815 Jan 30, 2023
3f1120e
Merge branch 'main' into diffusers_cross_attention_control_reimplemen…
damian0815 Jan 30, 2023
d044d4c
rename override/restore methods to better reflect what they actually do
damian0815 Jan 30, 2023
817e36f
Merge branch 'diffusers_cross_attention_control_reimplementation' of …
damian0815 Jan 30, 2023
40 changes: 18 additions & 22 deletions docs/features/PROMPTS.md
@@ -239,28 +239,24 @@
Generate an image with a given prompt, record the seed of the image, and then
use the `prompt2prompt` syntax to substitute words in the original prompt for
words in a new prompt. This works for `img2img` as well.

- `a ("fluffy cat").swap("smiling dog") eating a hotdog`.
- quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`.
- for single word substitutions parentheses are also optional:
`a cat.swap(dog) eating a hotdog`.
- Supports options `s_start`, `s_end`, `t_start`, `t_end` (each 0-1) loosely
corresponding to bloc97's `prompt_edit_spatial_start/_end` and
`prompt_edit_tokens_start/_end` but with the math swapped to make it easier to
intuitively understand.
- Example usage:`a (cat).swap(dog, s_end=0.3) eating a hotdog` - the `s_end`
argument means that the "spatial" (self-attention) edit will stop having any
effect after 30% (=0.3) of the steps have been done, leaving Stable
Diffusion with 70% of the steps where it is free to decide for itself how to
reshape the cat-form into a dog form.
- The numbers represent a percentage through the step sequence where the edits
should happen. 0 means the start (noisy starting image), 1 is the end (final
image).
- For img2img, the step sequence does not start at 0 but instead at
(1-strength) - so if strength is 0.7, s_start and s_end must both be
greater than 0.3 (1-0.7) to have any effect.
- Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable
Diffusion should have to change the shape of the subject being swapped.
- `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`.
For example, consider the prompt `a cat.swap(dog) playing with a ball in the forest`. Normally, because of the way words interact with each other during Stable Diffusion image generation, these two prompts would produce different compositions:
- `a cat playing with a ball in the forest`
- `a dog playing with a ball in the forest`

| `a cat playing with a ball in the forest` | `a dog playing with a ball in the forest` |
| --- | --- |
| (image) | (image) |


- For multiple-word swaps, use parentheses: `a (fluffy cat).swap(barking dog) playing with a ball in the forest`.
- To swap a comma, use quotes: `a ("fluffy, grey cat").swap("big, barking dog") playing with a ball in the forest`.
- Supports options `t_start` and `t_end` (each 0-1), loosely corresponding to bloc97's `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to understand intuitively. `t_start` and `t_end` control which steps cross-attention control runs on. With the default values `t_start=0` and `t_end=1`, cross-attention control is active on every step of image generation. Other values turn cross-attention control off for part of the generation process; the sketch after this list shows how the values map onto step indices.
- For example, with 10 steps of diffusion and the prompt `a cat.swap(dog, t_start=0.3, t_end=1.0) playing with a ball in the forest`, the first 3 steps will run as `a cat playing with a ball in the forest`, while the last 7 steps will run as `a dog playing with a ball in the forest`, but with the pixels that represent `dog` locked to the pixels that would have represented `cat` had the `cat` prompt been used instead.
- Conversely, for `a cat.swap(dog, t_start=0, t_end=0.7) playing with a ball in the forest`, the first 7 steps will run as `a dog playing with a ball in the forest`, with the pixels that represent `dog` locked to those that would have represented `cat`. The final 3 steps will simply run `a cat playing with a ball in the forest`.
> For img2img, the step sequence does not start at 0 but instead at `(1.0-strength)` - so if the img2img `strength` is `0.7`, `t_start` and `t_end` must both be greater than `0.3` (`1.0-0.7`) to have any effect.
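
The arithmetic above can be sketched in a few lines. This is illustrative only, not InvokeAI's actual implementation; it assumes steps are indexed `0..N-1`, evenly spaced across the 0-1 range, with the swap active on step `i` when `t_start <= i/N < t_end`:

```python
def swap_active_steps(num_steps: int, t_start: float, t_end: float) -> list[int]:
    """Indices of the steps on which the edited ('dog') prompt is used."""
    return [i for i in range(num_steps) if t_start <= i / num_steps < t_end]

print(swap_active_steps(10, t_start=0.3, t_end=1.0))  # [3, 4, 5, 6, 7, 8, 9] - last 7 steps
print(swap_active_steps(10, t_start=0.0, t_end=0.7))  # [0, 1, 2, 3, 4, 5, 6] - first 7 steps
```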

Prompt2prompt `.swap()` is not compatible with xformers, which will be temporarily disabled when doing a `.swap()` - so you should expect to use more VRAM and run slower than with xformers enabled.

The `prompt2prompt` code is based on
[bloc97's colab](https://github.com/bloc97/CrossAttentionControl).
97 changes: 55 additions & 42 deletions ldm/invoke/generator/diffusers_pipeline.py
@@ -24,9 +24,6 @@
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter

# monkeypatch diffusers CrossAttention 🙈
# this is to make prompt2prompt and (future) attention maps work
attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention

from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -295,7 +292,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True)
use_full_precision = (precision == 'float32' or precision == 'autocast')
self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
@@ -307,8 +304,23 @@
textual_inversion_manager=self.textual_inversion_manager
)

self._enable_memory_efficient_attention()


def _enable_memory_efficient_attention(self):
"""
if xformers is available, use it, otherwise use sliced attention.
"""
if is_xformers_available() and not Globals.disable_xformers:
self.enable_xformers_memory_efficient_attention()
else:
if torch.backends.mps.is_available():
# until pytorch #91617 is fixed, slicing is borked on MPS
# https://github.com/pytorch/pytorch/issues/91617
# fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
pass
else:
self.enable_attention_slicing(slice_size='auto')
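
For reference, the same fallback can be exercised against a stock diffusers pipeline. A minimal sketch, assuming the diffusers 0.12 helpers `enable_xformers_memory_efficient_attention()` and `enable_attention_slicing()` and the `is_xformers_available` utility; the model id is just an example:

```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

if is_xformers_available():
    pipe.enable_xformers_memory_efficient_attention()
elif not torch.backends.mps.is_available():
    # skip slicing on MPS until pytorch#91617 is fixed
    pipe.enable_attention_slicing(slice_size="auto")
```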

def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
conditioning_data: ConditioningData,
@@ -373,42 +385,40 @@ def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
if additional_guidance is None:
additional_guidance = []
extra_conditioning_info = conditioning_data.extra
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
step_count=len(self.scheduler.timesteps))
else:
self.invokeai_diffuser.remove_cross_attention_control()

yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
latents=latents)

batch_size = latents.shape[0]
batched_t = torch.full((batch_size,), timesteps[0],
dtype=timesteps.dtype, device=self.unet.device)
latents = self.scheduler.add_noise(latents, noise, batched_t)

attention_map_saver: Optional[AttentionMapSaver] = None
self.invokeai_diffuser.remove_attention_map_saving()
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t)
step_output = self.step(batched_t, latents, conditioning_data,
step_index=i,
total_step_count=len(timesteps),
additional_guidance=additional_guidance)
latents = step_output.prev_sample
predicted_original = getattr(step_output, 'pred_original_sample', None)

if i == len(timesteps)-1 and extra_conditioning_info is not None:
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
attention_map_token_ids = range(1, eos_token_index)
attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)

yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
predicted_original=predicted_original, attention_map_saver=attention_map_saver)

self.invokeai_diffuser.remove_attention_map_saving()
return latents, attention_map_saver
with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps)
):

yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
latents=latents)

batch_size = latents.shape[0]
batched_t = torch.full((batch_size,), timesteps[0],
dtype=timesteps.dtype, device=self.unet.device)
latents = self.scheduler.add_noise(latents, noise, batched_t)

attention_map_saver: Optional[AttentionMapSaver] = None

for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t)
step_output = self.step(batched_t, latents, conditioning_data,
step_index=i,
total_step_count=len(timesteps),
additional_guidance=additional_guidance)
latents = step_output.prev_sample
predicted_original = getattr(step_output, 'pred_original_sample', None)

# TODO resuscitate attention map saving
#if i == len(timesteps)-1 and extra_conditioning_info is not None:
# eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
# attention_map_token_ids = range(1, eos_token_index)
# attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
# self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)

yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
predicted_original=predicted_original, attention_map_saver=attention_map_saver)

return latents, attention_map_saver
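
Replacing the old setup/remove pair with a context manager guarantees that cross-attention control is torn down even if a step raises. A minimal sketch of the pattern, reusing the method names from the removed code; the real `custom_attention_context` lives on `InvokeAIDiffuserComponent` and may differ:

```python
from contextlib import contextmanager

@contextmanager
def custom_attention_context(diffuser, extra_conditioning_info, step_count):
    wants_control = (extra_conditioning_info is not None
                     and extra_conditioning_info.wants_cross_attention_control)
    if wants_control:
        diffuser.setup_cross_attention_control(extra_conditioning_info,
                                               step_count=step_count)
    try:
        yield
    finally:
        if wants_control:
            # restore the original attention behaviour even if denoising failed
            diffuser.remove_cross_attention_control()
```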

@torch.inference_mode()
def step(self, t: torch.Tensor, latents: torch.Tensor,
@@ -447,7 +457,7 @@ def step(self, t: torch.Tensor, latents: torch.Tensor,

return step_output

def _unet_forward(self, latents, t, text_embeddings):
def _unet_forward(self, latents, t, text_embeddings, cross_attention_kwargs: Optional[dict[str,Any]] = None):
"""predict the noise residual"""
if is_inpainting_model(self.unet) and latents.size(1) == 4:
# Pad out normal non-inpainting inputs for an inpainting model.
@@ -460,7 +470,10 @@ def _unet_forward(self, latents, t, text_embeddings):
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
).add_mask_channels(latents)

return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
return self.unet(sample=latents,
timestep=t,
encoder_hidden_states=text_embeddings,
cross_attention_kwargs=cross_attention_kwargs).sample
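
`cross_attention_kwargs` is the channel through which the swap context reaches the attention layers: in diffusers 0.12, keys passed to `unet(...)` are forwarded as keyword arguments to the installed attention processor. A hedged illustration; `KwargsAwareProcessor` is made up for this example, while `CrossAttnProcessor` and `set_attn_processor` are the diffusers 0.12 API:

```python
from diffusers.models.cross_attention import CrossAttnProcessor

class KwargsAwareProcessor(CrossAttnProcessor):
    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, **cross_attention_kwargs):
        # each key given to unet(..., cross_attention_kwargs={...})
        # arrives here as a keyword argument
        print("processor received:", sorted(cross_attention_kwargs))
        return super().__call__(attn, hidden_states,
                                encoder_hidden_states, attention_mask)

# installation (sketch): pipe.unet.set_attn_processor(KwargsAwareProcessor())
```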

def img2img_from_embeddings(self,
init_image: Union[torch.FloatTensor, PIL.Image.Image],
2 changes: 1 addition & 1 deletion ldm/invoke/prompt_parser.py
@@ -155,7 +155,7 @@ def __init__(self, original: list, edited: list, options: dict=None):
default_options = {
's_start': 0.0,
's_end': 0.2062994740159002, # ~= shape_freedom=0.5
't_start': 0.0,
't_start': 0.1,
't_end': 1.0
}
merged_options = default_options
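
With the new default `t_start=0.1`, a `.swap()` leaves the original prompt in control for the first ~10% of steps, letting the composition settle before the edit engages. Checking with the illustrative `swap_active_steps()` sketch from the PROMPTS.md section above:

```python
# 50-step generation with the new default window
steps = swap_active_steps(50, t_start=0.1, t_end=1.0)
print(steps[0], len(steps))  # 5 45: the swap is inactive on steps 0-4
```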