Skip to content

Commit 196aef5

Browse files
Fix pipeline dtype unexpected change when using SDXL reference community pipelines in float16 mode (#10670)
Fix pipeline dtype unexpected change when using SDXL reference community pipelines
1 parent 7b100ce commit 196aef5

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

examples/community/stable_diffusion_xl_controlnet_reference.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,8 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPi
193193

194194
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
195195
refimage = refimage.to(device=device)
196-
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
196+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
197+
if needs_upcasting:
197198
self.upcast_vae()
198199
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
199200
if refimage.dtype != self.vae.dtype:
@@ -223,6 +224,11 @@ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do
223224

224225
# aligning device to prevent device errors when concating it with the latent model input
225226
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
227+
228+
# cast back to fp16 if needed
229+
if needs_upcasting:
230+
self.vae.to(dtype=torch.float16)
231+
226232
return ref_image_latents
227233

228234
def prepare_ref_image(

examples/community/stable_diffusion_xl_reference.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ def retrieve_timesteps(
139139
class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
140140
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
141141
refimage = refimage.to(device=device)
142-
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
142+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
143+
if needs_upcasting:
143144
self.upcast_vae()
144145
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
145146
if refimage.dtype != self.vae.dtype:
@@ -169,6 +170,11 @@ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do
169170

170171
# aligning device to prevent device errors when concating it with the latent model input
171172
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
173+
174+
# cast back to fp16 if needed
175+
if needs_upcasting:
176+
self.vae.to(dtype=torch.float16)
177+
172178
return ref_image_latents
173179

174180
def prepare_ref_image(

0 commit comments

Comments (0)