Commit 7a9a7c2

Fix AutoencoderTiny encoder scaling convention
* Add [-1, 1] -> [0, 1] rescaling to EncoderTiny (this fixes huggingface#4676)
* Move [0, 1] -> [-1, 1] rescaling from AutoencoderTiny.decode to DecoderTiny (i.e. immediately after the final conv, as early as possible)
* Fix missing [0, 255] -> [0, 1] rescaling in AutoencoderTiny.forward
* Update AutoencoderTinyIntegrationTests to protect against scaling issues. The new test constructs a simple image, round-trips it through AutoencoderTiny (with and without tiling enabled), and confirms the decoded result is approximately equal to the source image, so it will fail if new AutoencoderTiny scaling issues are introduced.
* Context: raw TAESD weights expect images in [0, 1], but diffusers' convention represents images with zero-centered values in [-1, 1], so AutoencoderTiny needs to scale / unscale images at the start of encoding and at the end of decoding in order to work with diffusers (see the sketch below).
1 parent 74d902e commit 7a9a7c2
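To make the convention mismatch concrete, here is a minimal sketch of the two image ranges and the rescaling expressions this commit adds (standalone PyTorch; the tensor names are made up for illustration):

```python
import torch

# diffusers convention: images are zero-centered in [-1, 1]
diffusers_image = torch.rand(1, 3, 64, 64) * 2 - 1

# raw TAESD weights expect images in [0, 1], so EncoderTiny now rescales on the way in ...
taesd_image = diffusers_image.add(1).div(2)   # [-1, 1] -> [0, 1]

# ... and DecoderTiny rescales back on the way out
roundtrip = taesd_image.mul(2).sub(1)         # [0, 1] -> [-1, 1]

# the two rescalings are inverses, so the image range survives a round trip
assert torch.allclose(roundtrip, diffusers_image, atol=1e-6)
```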

3 files changed: +20 -21 lines
src/diffusers/models/autoencoder_tiny.py (+1 -4)

@@ -312,9 +312,6 @@ def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[Decode
             output = torch.cat(output)
         else:
             output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
-        # Refer to the following discussion to know why this is needed.
-        # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
-        output = output.mul_(2).sub_(1)
 
         if not return_dict:
             return (output,)
@@ -334,7 +331,7 @@ def forward(
         """
         enc = self.encode(sample).latents
         scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
-        unscaled_enc = self.unscale_latents(scaled_enc)
+        unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
         dec = self.decode(unscaled_enc)
 
         if not return_dict:
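The forward() change is subtle: scale_latents maps the latents into [0, 1] and they are then quantized to uint8 in [0, 255], so the bytes need to be divided by 255 before unscale_latents sees them. A rough standalone sketch of that round trip (scale_latents / unscale_latents below are simplified stand-ins with assumed constants, not the model's actual implementation):

```python
import torch

# simplified stand-ins for AutoencoderTiny.scale_latents / unscale_latents:
# an affine map into [0, 1] and its inverse (constants assumed for illustration)
latent_magnitude, latent_shift = 3.0, 0.5

def scale_latents(x):
    return x.div(2 * latent_magnitude).add(latent_shift).clamp(0, 1)

def unscale_latents(x):
    return x.sub(latent_shift).mul(2 * latent_magnitude)

enc = torch.rand(1, 4, 64, 64) * 6 - 3  # fake latents in [-3, 3]

# forward() quantizes the scaled latents to uint8 ...
scaled_enc = scale_latents(enc).mul_(255).round_().byte()

# ... so the bytes must be brought back into [0, 1] before unscaling (the fix above);
# without the / 255.0, unscale_latents would receive values in [0, 255]
unscaled_enc = unscale_latents(scaled_enc / 255.0)

# what remains is just uint8 quantization noise
assert (unscaled_enc - enc).abs().max() <= latent_magnitude / 255 + 1e-5
```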

src/diffusers/models/vae.py (+4 -2)

@@ -732,7 +732,8 @@ def custom_forward(*inputs):
                 x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
 
         else:
-            x = self.layers(x)
+            # scale image from [-1, 1] to [0, 1] to match TAESD convention
+            x = self.layers(x.add(1).div(2))
 
         return x
 
@@ -790,4 +791,5 @@ def custom_forward(*inputs):
         else:
             x = self.layers(x)
 
-        return x
+        # scale image from [0, 1] to [-1, 1] to match diffusers convention
+        return x.mul(2).sub(1)
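Taken together, the vae.py changes put both rescalings at the module boundaries: the encoder rescales its input before the first layer, and the decoder rescales its output immediately after the last one. A simplified sketch of the resulting structure (the gradient-checkpointing branch and the real conv stacks are omitted; `layers` is a placeholder module):

```python
import torch
from torch import nn

class TinyEncoderSketch(nn.Module):
    """Accepts diffusers-style images in [-1, 1]; the wrapped layers see [0, 1]."""

    def __init__(self, layers: nn.Module):
        super().__init__()
        self.layers = layers

    def forward(self, x):
        # scale image from [-1, 1] to [0, 1] to match TAESD convention
        return self.layers(x.add(1).div(2))


class TinyDecoderSketch(nn.Module):
    """The wrapped layers produce [0, 1]; callers get diffusers-style [-1, 1] images."""

    def __init__(self, layers: nn.Module):
        super().__init__()
        self.layers = layers

    def forward(self, x):
        x = self.layers(x)
        # scale image from [0, 1] to [-1, 1] to match diffusers convention
        return x.mul(2).sub(1)


# quick smoke test with identity "layers": a [-1, 1] image survives the round trip
encoder, decoder = TinyEncoderSketch(nn.Identity()), TinyDecoderSketch(nn.Identity())
image = torch.rand(1, 3, 8, 8) * 2 - 1
assert torch.allclose(decoder(encoder(image)), image, atol=1e-6)
```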

tests/models/test_models_vae.py (+15 -15)

@@ -270,14 +270,6 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def get_file_format(self, seed, shape):
-        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
-
-    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
-        dtype = torch.float16 if fp16 else torch.float32
-        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
-        return image
-
     def get_sd_vae_model(self, model_id="hf-internal-testing/taesd-diffusers", fp16=False):
         torch_dtype = torch.float16 if fp16 else torch.float32
 
@@ -302,19 +294,27 @@ def test_tae_tiling(self, in_shape, out_shape):
             dec = model.decode(zeros).sample
         assert dec.shape == out_shape
 
-    def test_stable_diffusion(self):
+    @parameterized.expand([True, False])
+    def test_tae_roundtrip(self, enable_tiling):
+        # load the autoencoder
         model = self.get_sd_vae_model()
-        image = self.get_sd_image(seed=33)
+        if enable_tiling:
+            model.enable_tiling()
+
+        # make a black image with a white square in the middle,
+        # which is large enough to split across multiple tiles
+        image = -torch.ones(1, 3, 1024, 1024, device=torch_device)
+        image[..., 256:768, 256:768] = 1.0
 
+        # round-trip the image through the autoencoder
         with torch.no_grad():
             sample = model(image).sample
 
-        assert sample.shape == image.shape
-
-        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
-        expected_output_slice = torch.tensor([0.9858, 0.9262, 0.8629, 1.0974, -0.091, -0.2485, 0.0936, 0.0604])
+        # the autoencoder reconstruction should match original image, sorta
+        def downscale(x):
+            return torch.nn.functional.avg_pool2d(x, model.spatial_scale_factor)
+        assert torch_all_close(downscale(sample), downscale(image), atol=0.125)
 
-        assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
 
 @slow
