From 1c7c6ba5184a0d30d9992b7a71f5c6d4a8090413 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Wed, 28 Jun 2023 12:42:27 +0200
Subject: [PATCH 1/3] Model offload.

---
 .../pipeline_stable_diffusion_xl.py          | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 271d322c896a..0919c83a0994 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -228,12 +228,18 @@ def enable_model_cpu_offload(self, gpu_id=0):
         self.to("cpu", silence_dtype_warnings=True)
         torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
+        model_sequence = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )
+        model_sequence.extend([self.unet, self.vae])
+
         hook = None
-        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+        for cpu_offloaded_model in model_sequence:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
 
-        if self.safety_checker is not None:
-            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+        # TODO: safety_checker
+        # if self.safety_checker is not None:
+        #     _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
 
         # We'll offload the last model manually.
         self.final_offload_hook = hook

From 0393c4d4f3bad6d44c7d176f30ed5a722a396aa5 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Wed, 28 Jun 2023 13:17:32 +0200
Subject: [PATCH 2/3] Model offload for refiner / img2img

---
 .../pipeline_stable_diffusion_xl_img2img.py  | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 8104dba2726d..820e85d59d0b 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -233,12 +233,18 @@ def enable_model_cpu_offload(self, gpu_id=0):
         self.to("cpu", silence_dtype_warnings=True)
         torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
+        model_sequence = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )
+        model_sequence.extend([self.unet, self.vae])
+
         hook = None
-        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+        for cpu_offloaded_model in model_sequence:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
 
-        if self.safety_checker is not None:
-            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+        # TODO: safety_checker
+        # if self.safety_checker is not None:
+        #     _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
 
         # We'll offload the last model manually.
         self.final_offload_hook = hook

From 6fd4aaa802775ba72e8328a29d4d509c04527e5c Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Wed, 28 Jun 2023 18:52:14 +0200
Subject: [PATCH 3/3] Hardcode encoder offload on img2img vae encode

Saves some GPU RAM in img2img / refiner tasks so it remains below 8 GB.
---
 .../pipeline_stable_diffusion_xl_img2img.py  | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 820e85d59d0b..c2fe22e5e932 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -532,6 +532,11 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )
 
+        # Offload text encoder if `enable_model_cpu_offload` was enabled
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.text_encoder_2.to("cpu")
+            torch.cuda.empty_cache()
+
         image = image.to(device=device, dtype=dtype)
 
         batch_size = batch_size * num_images_per_prompt
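
For reference, a minimal sketch of how the offload path added in this series is
exercised end to end across the base and refiner pipelines. The checkpoint IDs
below are assumptions for illustration, not pinned by these patches; substitute
whichever SDXL base/refiner weights you are testing with.

    import torch
    from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline

    # Base pipeline: enable_model_cpu_offload() installs the hooks from PATCH 1/3,
    # moving each model (text encoders, unet, vae) to the GPU only while it runs.
    base = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",  # assumed checkpoint ID
        torch_dtype=torch.float16,
    )
    base.enable_model_cpu_offload()
    latents = base(prompt="an astronaut riding a horse", output_type="latent").images

    # Refiner / img2img pipeline: PATCH 2/3 enables the same offload sequence, and
    # PATCH 3/3 additionally moves text_encoder_2 to the CPU before the VAE encode.
    refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",  # assumed checkpoint ID
        torch_dtype=torch.float16,
    )
    refiner.enable_model_cpu_offload()
    image = refiner(prompt="an astronaut riding a horse", image=latents).images[0]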