[Bugfix] Renames in 0.15.0 diffusers #3184

Merged · 14 commits · Apr 27, 2023
2 changes: 0 additions & 2 deletions .gitignore
@@ -233,5 +233,3 @@ installer/install.sh
 installer/update.bat
 installer/update.sh
 
-# no longer stored in source directory
-models
15 changes: 11 additions & 4 deletions ldm/invoke/generator/diffusers_pipeline.py
@@ -400,8 +400,15 @@ def device(self) -> torch.device:
     @property
     def _submodels(self) -> Sequence[torch.nn.Module]:
         module_names, _, _ = self.extract_init_dict(dict(self.config))
-        values = [getattr(self, name) for name in module_names.keys()]
-        return [m for m in values if isinstance(m, torch.nn.Module)]
+        submodels = []
+        for name in module_names.keys():
+            if hasattr(self, name):
+                value = getattr(self, name)
+            else:
+                value = getattr(self.config, name)
+            if isinstance(value, torch.nn.Module):
+                submodels.append(value)
+        return submodels
 
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
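
Context for this hunk: in diffusers 0.15, plain config entries (for example `requires_safety_checker`) are apparently no longer mirrored as attributes on the pipeline object, so a bare `getattr(self, name)` can raise `AttributeError`. A minimal sketch of the fallback pattern, assuming `pipe.config` is a diffusers `FrozenDict` that supports attribute access (`get_pipeline_value` is a hypothetical helper, not InvokeAI code):

def get_pipeline_value(pipe, name: str):
    # Submodels (unet, vae, ...) are real attributes on the pipeline;
    # plain config entries may exist only in pipe.config on diffusers 0.15.
    if hasattr(pipe, name):
        return getattr(pipe, name)
    return getattr(pipe.config, name)  # FrozenDict allows attribute access
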
@@ -472,7 +479,7 @@ def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
                 step_count=len(self.scheduler.timesteps)
         ):
 
-            yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
+            yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.config.num_train_timesteps,
                                             latents=latents)
 
             batch_size = latents.shape[0]
@@ -756,7 +763,7 @@ def _tokenize(self, prompt: Union[str, List[str]]):
     @property
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
-        return self.unet.in_channels
+        return self.unet.config.in_channels
 
     def decode_latents(self, latents):
         # Explicit call to get the vae loaded, since `decode` isn't the forward method.
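
The two one-line changes above follow the same 0.15 pattern: values that were passed to `__init__` (such as `num_train_timesteps` and `in_channels`) are now read from the `.config` namespace rather than from the scheduler or UNet object directly. A version-tolerant sketch (assumption: pre-0.15 releases still expose these as plain attributes):

def scheduler_num_train_timesteps(scheduler) -> int:
    # Prefer the 0.15-style config namespace, fall back to the old attribute.
    config = getattr(scheduler, "config", None)
    if config is not None and hasattr(config, "num_train_timesteps"):
        return config.num_train_timesteps
    return scheduler.num_train_timesteps  # pre-0.15 fallback

def unet_in_channels(unet) -> int:
    config = getattr(unet, "config", None)
    if config is not None and hasattr(config, "in_channels"):
        return config.in_channels
    return unet.in_channels  # pre-0.15 fallback
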
17 changes: 8 additions & 9 deletions ldm/models/diffusion/cross_attention_control.py
@@ -14,7 +14,6 @@
 
 from compel.cross_attention_control import Arguments
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
-from diffusers.models.cross_attention import AttnProcessor
 from ldm.invoke.devices import torch_dtype
 
 
@@ -163,7 +162,7 @@ def offload_saved_attention_slices_to_cpu(self):
 
 class InvokeAICrossAttentionMixin:
     """
-    Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
+    Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
     through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
     and dynamic slicing strategy selection.
     """
@@ -178,7 +177,7 @@ def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, t
         Set custom attention calculator to be called when attention is calculated
         :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
                          which returns either the suggested_attention_slice or an adjusted equivalent.
-            `module` is the current CrossAttention module for which the callback is being invoked.
+            `module` is the current Attention module for which the callback is being invoked.
             `suggested_attention_slice` is the default-calculated attention slice
             `dim` is -1 if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
             If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
@@ -326,7 +325,7 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel,
 
 
 def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
-    from ldm.modules.attention import CrossAttention # avoid circular import
+    from ldm.modules.attention import CrossAttention # avoid circular import # TODO: rename as in diffusers?
     cross_attention_class: type = InvokeAIDiffusersCrossAttention if isinstance(model,UNet2DConditionModel) else CrossAttention
     which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
     attention_module_tuples = [(name,module) for name, module in model.named_modules() if
@@ -432,7 +431,7 @@ def get_mem_free_total(device):
 
 
 
-class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):
+class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -457,8 +456,8 @@ def _attention(self, query, key, value, attention_mask=None):
     """
     # base implementation
 
-    class CrossAttnProcessor:
-        def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    class AttnProcessor:
+        def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
             batch_size, sequence_length, _ = hidden_states.shape
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
@@ -487,7 +486,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None,
 
 import torch
 
-from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
+from diffusers.models.attention_processor import Attention, AttnProcessor, SlicedAttnProcessor
 
 
 @dataclass
@@ -532,7 +531,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
 
     # TODO: dynamically pick slice size based on memory conditions
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
                  # kwargs
                  swap_cross_attn_context: SwapCrossAttnContext=None):
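
All of the renames in this file trace back to one upstream move: diffusers 0.15 relocated the attention classes from `diffusers.models.cross_attention` to `diffusers.models.attention_processor` and renamed `CrossAttention`/`CrossAttnProcessor` to `Attention`/`AttnProcessor`. A hedged import shim for code that has to run on both sides of the rename (a sketch, not part of this PR):

try:
    # diffusers >= 0.15
    from diffusers.models.attention_processor import (
        Attention,
        AttnProcessor,
        SlicedAttnProcessor,
    )
except ImportError:
    # diffusers < 0.15: same classes under their old names and module
    from diffusers.models.cross_attention import (
        CrossAttention as Attention,
        CrossAttnProcessor as AttnProcessor,
        SlicedAttnProcessor,
    )
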
1 change: 0 additions & 1 deletion ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -5,7 +5,6 @@
 
 import numpy as np
 import torch
-
 from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -34,7 +34,7 @@ dependencies = [
     "clip_anytorch",
     "compel~=1.1.0",
     "datasets",
-    "diffusers[torch]==0.14",
+    "diffusers[torch]~=0.15.0",
     "dnspython==2.2.1",
     "einops",
     "eventlet",
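
Note that the specifier changes shape as well as version: `==0.14` pinned a single release, while `~=0.15.0` is a PEP 440 compatible-release constraint equivalent to `>=0.15.0, ==0.15.*`, so future patch releases are picked up automatically but 0.16 is not. A quick check of that behavior using the `packaging` library (sketch):

from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=0.15.0")   # same as ">=0.15.0, ==0.15.*"
assert "0.15.0" in spec
assert "0.15.1" in spec           # patch releases satisfy the constraint
assert "0.16.0" not in spec       # minor version bumps do not
assert "0.14" not in spec         # the previously pinned release is excluded
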