Skip to content

Commit 8a3f46f

Browse files
MrSyeeJimmy
authored and
Jimmy
committed
Add vae tiling and slicing in img2img and inpaint (huggingface#6871)
* Add vae tiling in img2img and inpaint * Add vae tiling not slicing
1 parent 8062f07 commit 8a3f46f

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

+33
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,39 @@ def __init__(
288288
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
289289
self.register_to_config(requires_safety_checker=requires_safety_checker)
290290

291+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
292+
def enable_vae_slicing(self):
293+
r"""
294+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
295+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
296+
"""
297+
self.vae.enable_slicing()
298+
299+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
300+
def disable_vae_slicing(self):
301+
r"""
302+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
303+
computing decoding in one step.
304+
"""
305+
self.vae.disable_slicing()
306+
307+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
308+
def enable_vae_tiling(self):
309+
r"""
310+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
311+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
312+
processing larger images.
313+
"""
314+
self.vae.enable_tiling()
315+
316+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
317+
def disable_vae_tiling(self):
318+
r"""
319+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
320+
computing decoding in one step.
321+
"""
322+
self.vae.disable_tiling()
323+
291324
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
292325
def _encode_prompt(
293326
self,

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

+33
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,39 @@ def __init__(
360360
)
361361
self.register_to_config(requires_safety_checker=requires_safety_checker)
362362

363+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
364+
def enable_vae_slicing(self):
365+
r"""
366+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
367+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
368+
"""
369+
self.vae.enable_slicing()
370+
371+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
372+
def disable_vae_slicing(self):
373+
r"""
374+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
375+
computing decoding in one step.
376+
"""
377+
self.vae.disable_slicing()
378+
379+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
380+
def enable_vae_tiling(self):
381+
r"""
382+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
383+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
384+
processing larger images.
385+
"""
386+
self.vae.enable_tiling()
387+
388+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
389+
def disable_vae_tiling(self):
390+
r"""
391+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
392+
computing decoding in one step.
393+
"""
394+
self.vae.disable_tiling()
395+
363396
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
364397
def _encode_prompt(
365398
self,

0 commit comments

Comments
 (0)