
refactor: move model helper function in pipeline to a mixin class #6571


Merged: 23 commits on Feb 28, 2024
Commits (23): changes from all commits shown
130be1d
move model helper function in pipeline to EfficiencyMixin
ultranity Jan 11, 2024
ec74982
deduplicate functions replaced by EfficiencyMixin
ultranity Jan 11, 2024
4a7fc38
add mixin to rdm & restore audioldm2 & fix quality checks
ultranity Jan 15, 2024
cc4f805
rebase on main branch
ultranity Feb 19, 2024
fc71e97
init PipelineEfficiencyFunctionTesterMixin
ultranity Feb 19, 2024
95f53e6
rename EfficiencyMixin to LatentDiffusionMixin
ultranity Feb 21, 2024
6c11d6a
add LDM_component test for pipeline with LatentDiffusionMixin
ultranity Feb 21, 2024
48d8b67
Merge branch 'main' into efficiency_pipe_util
yiyixuxu Feb 22, 2024
4602bac
rename EfficiencyMixin to StableDiffusionMixin
ultranity Feb 22, 2024
7373eb0
Merge branch 'main' into efficiency_pipe_util
ultranity Feb 23, 2024
ebfd3a7
Add more SDFunctionTesterMixin to cover different UNet type
ultranity Feb 23, 2024
066c14d
add StableDiffusionMixin to InstaFlowPipeline
ultranity Feb 23, 2024
fdc43c5
remove StableDiffusionMixin from UniDiffuserPipeline
ultranity Feb 23, 2024
0099144
Update tests/pipelines/test_pipelines_common.py
yiyixuxu Feb 24, 2024
4a294ec
make SDFunctionTesterMixin run on non-image diffsuion pipeline
ultranity Feb 25, 2024
b3c3de0
fix fuse_projection by check is_cross_attention when init
ultranity Feb 25, 2024
9920fc9
Merge branch 'main' into efficiency_pipe_util
ultranity Feb 27, 2024
7cdff34
Merge branch 'main' into efficiency_pipe_util
sayakpaul Feb 27, 2024
994299c
use get_dummy_inputs for test_vae_tiling and test_freeu
ultranity Feb 27, 2024
a076831
fix I2V gen test error
ultranity Feb 27, 2024
661b1b5
Merge branch 'main' into efficiency_pipe_util
ultranity Feb 27, 2024
0fd684b
add missing StableDiffusionMixin
ultranity Feb 28, 2024
4bf6b55
Merge branch 'main' into efficiency_pipe_util
ultranity Feb 28, 2024
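Taken together, these commits delete per-pipeline copies of the memory and speed helpers and inherit them from a shared mixin instead. The following is a minimal sketch of that pattern; the method names and bodies mirror the helpers removed in the diffs below, and the actual StableDiffusionMixin added to pipeline_utils.py may contain additional methods.

class StableDiffusionMixin:
    """Shared helpers for Stable Diffusion style pipelines.

    Each method delegates to the registered component (self.vae or self.unet),
    so any pipeline that registers those modules can inherit the helpers
    instead of re-implementing them.
    """

    def enable_vae_slicing(self):
        # Decode latents in slices to reduce peak memory.
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        # Decode/encode in tiles so larger images fit in memory.
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        self.vae.disable_tiling()

    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
        # FreeU (https://arxiv.org/abs/2309.11497): rescale skip and backbone features.
        if not hasattr(self, "unet"):
            raise ValueError("The pipeline must have `unet` for using FreeU.")
        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)

    def disable_freeu(self):
        self.unet.disable_freeu()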
14 changes: 2 additions & 12 deletions examples/community/clip_guided_images_mixing_stable_diffusion.py
@@ -12,12 +12,12 @@
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import PIL_INTERPOLATION
from diffusers.utils.torch_utils import randn_tensor
@@ -77,7 +77,7 @@ def set_requires_grad(model, value):
param.requires_grad = value


class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
def __init__(
self,
vae: AutoencoderKL,
@@ -113,16 +113,6 @@ def __init__(
set_requires_grad(self.text_encoder, False)
set_requires_grad(self.clip_model, False)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
self.enable_attention_slicing(None)

def freeze_vae(self):
set_requires_grad(self.vae, False)

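With the override deleted above, attention slicing is no longer defined on the community pipeline itself; the calls below go through the implementation inherited from the base classes, with the same behavior as the removed methods. A usage sketch, assuming `pipe` is an already constructed CLIPGuidedImagesMixingStableDiffusion instance:

# `pipe` is assumed to be an already constructed CLIPGuidedImagesMixingStableDiffusion.
pipe.enable_attention_slicing("auto")  # inherited; "auto" halves attention_head_dim, as the deleted override did
pipe.disable_attention_slicing()       # inherited; restores single-step attention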
14 changes: 2 additions & 12 deletions examples/community/clip_guided_stable_diffusion.py
@@ -10,12 +10,12 @@
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput


@@ -51,7 +51,7 @@ def set_requires_grad(model, value):
param.requires_grad = value


class CLIPGuidedStableDiffusion(DiffusionPipeline):
class CLIPGuidedStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
"""CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
- https://github.com/Jack000/glid-3-xl
- https://github.dev/crowsonkb/k-diffusion
@@ -89,16 +89,6 @@ def __init__(
set_requires_grad(self.text_encoder, False)
set_requires_grad(self.clip_model, False)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
self.enable_attention_slicing(None)

def freeze_vae(self):
set_requires_grad(self.vae, False)

14 changes: 2 additions & 12 deletions examples/community/clip_guided_stable_diffusion_img2img.py
@@ -12,12 +12,12 @@
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import PIL_INTERPOLATION, deprecate
from diffusers.utils.torch_utils import randn_tensor
@@ -125,7 +125,7 @@ def set_requires_grad(model, value):
param.requires_grad = value


class CLIPGuidedStableDiffusion(DiffusionPipeline):
class CLIPGuidedStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
"""CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
- https://github.com/Jack000/glid-3-xl
- https://github.dev/crowsonkb/k-diffusion
@@ -163,16 +163,6 @@ def __init__(
set_requires_grad(self.text_encoder, False)
set_requires_grad(self.clip_model, False)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
self.enable_attention_slicing(None)

def freeze_vae(self):
set_requires_grad(self.vae, False)

61 changes: 3 additions & 58 deletions examples/community/composable_stable_diffusion.py
@@ -22,6 +22,7 @@
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import (
@@ -32,13 +33,13 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import deprecate, is_accelerate_available, logging
from diffusers.utils import deprecate, logging


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class ComposableStableDiffusionPipeline(DiffusionPipeline):
class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion.

@@ -164,62 +165,6 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding.

When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()

def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()

def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")

device = torch.device(f"cuda:{gpu_id}")

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)

if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# fix by only offloading self.safety_checker for now
cpu_offload(self.safety_checker.vision_model, device)

@property
def _execution_device(self):
Review comment (Collaborator):

Would leave this as is in the Pipeline for now and not add to the Mixin.

Reply (@ultranity, Contributor, Author, Jan 15, 2024):

I apologize, some of these changes should have been split into a separate commit (e.g. "remove unnecessary overload ..."), but since these functions (enable_sequential_cpu_offload, _execution_device, enable_attention_slicing) are inherited from DiffusionPipeline, they can be removed without changing any behaviour, right?

r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device

def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
Encodes the prompt into text encoder hidden states.
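As the review thread above notes, offloading and execution-device resolution are inherited from DiffusionPipeline, so the pipeline-local copies can be dropped without changing behavior. A usage sketch, assuming an already constructed ComposableStableDiffusionPipeline and an available CUDA device:

# `pipe` is assumed to be an already constructed ComposableStableDiffusionPipeline.
pipe.enable_vae_slicing()                      # provided by StableDiffusionMixin
pipe.enable_sequential_cpu_offload(gpu_id=0)   # provided by the base DiffusionPipeline
device = pipe._execution_device                # resolved via Accelerate hooks after offloading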
58 changes: 2 additions & 56 deletions examples/community/gluegen.py
@@ -10,6 +10,7 @@
from diffusers.loaders import LoraLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
@@ -193,7 +194,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class GlueGenStableDiffusionPipeline(DiffusionPipeline, LoraLoaderMixin):
class GlueGenStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin, LoraLoaderMixin):
def __init__(
self,
vae: AutoencoderKL,
@@ -241,35 +242,6 @@ def load_language_adapter(
)
self.language_adapter.load_state_dict(torch.load(model_path))

def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()

def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()

def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()

def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()

def _adapt_language(self, prompt_embeds: torch.FloatTensor):
prompt_embeds = prompt_embeds / 3
prompt_embeds = self.language_adapter(prompt_embeds) * (self.tensor_norm / 2)
@@ -544,32 +516,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
latents = latents * self.scheduler.init_noise_sigma
return latents

def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.

The suffixes after the scaling factors represent the stages where they are being applied.

Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
if not hasattr(self, "unet"):
raise ValueError("The pipeline must have `unet` for using FreeU.")
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)

def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()

# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
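After this removal, FreeU and the VAE tiling toggles come from the shared mixin with unchanged call signatures. A usage sketch; the scaling factors are illustrative values in the range reported by the FreeU repository for SD 1.x-style models, not values prescribed by this PR:

# `pipe` is assumed to be an already constructed GlueGenStableDiffusionPipeline.
pipe.enable_vae_tiling()                             # tile VAE decode/encode for large images
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)    # attenuate skip features, boost backbone features
# ... run inference ...
pipe.disable_freeu()
pipe.disable_vae_tiling()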
28 changes: 2 additions & 26 deletions examples/community/imagic_stable_diffusion.py
@@ -19,6 +19,7 @@

from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -56,7 +57,7 @@ def preprocess(image):
return 2.0 * image - 1.0


class ImagicStableDiffusionPipeline(DiffusionPipeline):
class ImagicStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
r"""
Pipeline for imagic image editing.
See paper here: https://arxiv.org/pdf/2210.09276.pdf
@@ -105,31 +106,6 @@ def __init__(
feature_extractor=feature_extractor,
)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)

def train(
self,
prompt: Union[str, List[str]],
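The Imagic pipeline drops the same attention-slicing override; per the removed docstring, the integer form of slice_size still controls how many slices are used. A brief sketch, assuming an already constructed ImagicStableDiffusionPipeline:

# `pipe` is assumed to be an already constructed ImagicStableDiffusionPipeline.
pipe.enable_attention_slicing(slice_size=2)  # attention computed in attention_head_dim // 2 slices
pipe.disable_attention_slicing()             # back to single-step attention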
27 changes: 0 additions & 27 deletions examples/community/img2img_inpainting.py
@@ -129,33 +129,6 @@ def __init__(
feature_extractor=feature_extractor,
)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.

When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.

Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)

@torch.no_grad()
def __call__(
self,