
Commit a2dbad4

Fix AutoencoderTiny encoder scaling convention
* Add [-1, 1] -> [0, 1] rescaling to EncoderTiny (this fixes huggingface#4676)
* Move [0, 1] -> [-1, 1] rescaling from AutoencoderTiny.decode to DecoderTiny (i.e. immediately after the final conv, as early as possible)
* Fix missing [0, 255] -> [0, 1] rescaling in AutoencoderTiny.forward
* Update AutoencoderTinyIntegrationTests to protect against scaling issues. The new test constructs a simple image, round-trips it through AutoencoderTiny, and confirms the decoded result is approximately equal to the source image. This test will fail if new AutoencoderTiny scaling issues are introduced.
* Context: Raw TAESD weights expect images in [0, 1], but diffusers' convention represents images with zero-centered values in [-1, 1], so AutoencoderTiny needs to scale / unscale images at the start of encoding and at the end of decoding in order to work with diffusers.
1 parent 74d902e commit a2dbad4
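The convention mismatch behind the fix, as a runnable sketch. The helper names below are illustrative, not diffusers API; the actual rescalings are inlined in EncoderTiny.forward and DecoderTiny.forward (see the vae.py diff below):

```python
import torch

# Illustrative helpers (not diffusers API): diffusers represents images as
# zero-centered tensors in [-1, 1], while raw TAESD weights expect [0, 1].

def to_taesd(x: torch.Tensor) -> torch.Tensor:
    # [-1, 1] -> [0, 1], applied at the start of encoding
    return x.mul(0.5).add(0.5)

def to_diffusers(x: torch.Tensor) -> torch.Tensor:
    # [0, 1] -> [-1, 1], applied at the end of decoding
    return x.mul(2).sub(1)

image = torch.rand(1, 3, 64, 64) * 2 - 1  # an image in diffusers' [-1, 1] terms
assert torch.allclose(to_diffusers(to_taesd(image)), image, atol=1e-6)
```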

File tree

3 files changed: +16 −13 lines changed

src/diffusers/models/autoencoder_tiny.py

+1 −4

@@ -312,9 +312,6 @@ def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
             output = torch.cat(output)
         else:
             output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
-        # Refer to the following discussion to know why this is needed.
-        # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
-        output = output.mul_(2).sub_(1)
 
         if not return_dict:
             return (output,)

@@ -334,7 +331,7 @@ def forward(
         """
         enc = self.encode(sample).latents
         scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
-        unscaled_enc = self.unscale_latents(scaled_enc)
+        unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
         dec = self.decode(unscaled_enc)
 
         if not return_dict:
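To see why the new `/ 255.0` matters: `scale_latents` maps raw latents into [0, 1], and `forward` then quantizes them to bytes in [0, 255]. The old code fed the byte tensor straight back into `unscale_latents`, treating [0, 255] values as if they were in [0, 1]. A hedged sketch, with stand-in implementations of the two latent-scaling helpers (the `latent_magnitude` / `latent_shift` values are assumptions):

```python
import torch

# Stand-ins for AutoencoderTiny.scale_latents / unscale_latents, which map
# raw latents into [0, 1] and back; magnitude/shift values are assumptions.
latent_magnitude, latent_shift = 3.0, 0.5

def scale_latents(x):
    # raw latents -> [0, 1]
    return x.div(2 * latent_magnitude).add(latent_shift).clamp(0, 1)

def unscale_latents(x):
    # [0, 1] -> raw latents
    return x.sub(latent_shift).mul(2 * latent_magnitude)

enc = torch.randn(1, 4, 8, 8)                          # raw latents
scaled = scale_latents(enc).mul_(255).round_().byte()  # quantize to [0, 255]

# Dividing by 255 restores the [0, 1] range before unscaling, so the round
# trip only loses quantization precision (and clamping of outlier latents).
unscaled = unscale_latents(scaled / 255.0)
max_err = 2 * latent_magnitude / 255  # one quantization step after unscaling
assert torch.allclose(unscaled, enc.clamp(-latent_magnitude, latent_magnitude), atol=max_err)
```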

src/diffusers/models/vae.py

+4 −2

@@ -732,7 +732,8 @@ def custom_forward(*inputs):
             x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
 
         else:
-            x = self.layers(x)
+            # scale image from [-1, 1] to [0, 1] to match TAESD convention
+            x = self.layers(x.mul(0.5).add(0.5))
 
         return x
 

@@ -790,4 +791,5 @@ def custom_forward(*inputs):
         else:
             x = self.layers(x)
 
-        return x
+        # scale image from [0, 1] to [-1, 1] to match diffusers convention
+        return x.mul(2).sub(1)
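A placement sketch for the two rescalings above. The class names are hypothetical and `nn.Identity` stands in for the real EncoderTiny / DecoderTiny conv stacks; only where the rescaling sits is the point:

```python
import torch
from torch import nn

class EncoderTinySketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Identity()  # stand-in for the TAESD encoder convs

    def forward(self, x):
        # scale image from [-1, 1] to [0, 1] before the TAESD layers
        return self.layers(x.mul(0.5).add(0.5))

class DecoderTinySketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Identity()  # stand-in for the TAESD decoder convs

    def forward(self, x):
        # scale image from [0, 1] back to [-1, 1] after the final layer
        return self.layers(x).mul(2).sub(1)

image = torch.rand(1, 3, 64, 64) * 2 - 1  # [-1, 1]
roundtrip = DecoderTinySketch()(EncoderTinySketch()(image))
assert torch.allclose(roundtrip, image, atol=1e-6)
```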

tests/models/test_models_vae.py

+11 −7

@@ -302,19 +302,23 @@ def test_tae_tiling(self, in_shape, out_shape):
         dec = model.decode(zeros).sample
         assert dec.shape == out_shape
 
-    def test_stable_diffusion(self):
+    def test_roundtrip(self):
+        # load the autoencoder
         model = self.get_sd_vae_model()
-        image = self.get_sd_image(seed=33)
 
+        # make a black image with white square in the middle
+        image = -torch.ones(1, 3, 512, 512, device=torch_device)
+        image[..., 128:384, 128:384] = 1.0
+
+        # round-trip the image through the autoencoder
         with torch.no_grad():
             sample = model(image).sample
 
-        assert sample.shape == image.shape
+        # the autoencoder reconstruction should match original image, sorta
+        def downscale(x):
+            return torch.nn.functional.avg_pool2d(x, 8)
+        assert torch_all_close(downscale(sample), downscale(image), atol=0.125)
 
-        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
-        expected_output_slice = torch.tensor([0.9858, 0.9262, 0.8629, 1.0974, -0.091, -0.2485, 0.0936, 0.0604])
-
-        assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
 
 
 @slow
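The new round-trip test, approximated as a standalone script. This assumes the checkpoint under test is the published TAESD weights (`madebyollin/taesd` on the Hub) and runs on CPU for simplicity:

```python
import torch
import torch.nn.functional as F
from diffusers import AutoencoderTiny

# Assumption: the published TAESD weights, loaded without the test harness.
model = AutoencoderTiny.from_pretrained("madebyollin/taesd")

# black image with a white square in the middle, in diffusers' [-1, 1] terms
image = -torch.ones(1, 3, 512, 512)
image[..., 128:384, 128:384] = 1.0

# round-trip the image through the autoencoder
with torch.no_grad():
    sample = model(image).sample

# Compare 8x average-pooled copies: pooling forgives high-frequency
# reconstruction noise, while a scaling bug shifts every pixel and still fails.
assert torch.allclose(F.avg_pool2d(sample, 8), F.avg_pool2d(image, 8), atol=0.125)
```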
