Skip to content

[ldm3d] Update code to be functional with the new checkpoints #3875

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions src/diffusers/image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,17 @@ def numpy_to_depth(self, images):
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
raise Exception("Not supported")
images_depth = images[:, :, :, 3:]
if images.shape[-1] == 6:
images_depth = (images_depth * 255).round().astype("uint8")
pil_images = [
Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
]
elif images.shape[-1] == 4:
images_depth = (images_depth * 65535.0).astype(np.uint16)
pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
else:
pil_images = [Image.fromarray(self.rgblike_to_depthmap(image[:, :, 3:]), mode="I;16") for image in images]
raise Exception("Not supported")

return pil_images

Expand Down Expand Up @@ -349,7 +354,11 @@ def postprocess(
image = self.pt_to_numpy(image)

if output_type == "np":
return image[:, :, :, :3], np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
if image.shape[-1] == 6:
image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
else:
image_depth = image[:, :, :, 3:]
return image[:, :, :, :3], image_depth

if output_type == "pil":
return self.numpy_to_pil(image), self.numpy_to_depth(image)
Expand Down
32 changes: 26 additions & 6 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_ldm3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ def test_stable_diffusion_ddim(self):
assert depth.shape == (1, 64, 64)

expected_slice_rgb = np.array(
[0.37301102, 0.7023895, 0.7418312, 0.5163375, 0.5825485, 0.60929704, 0.4188174, 0.48407027, 0.46555096]
[0.37338176, 0.70247, 0.74203193, 0.51643604, 0.58256793, 0.60932136, 0.4181095, 0.48355877, 0.46535262]
)
expected_slice_depth = np.array([103.4673, 85.81202, 87.84926])
expected_slice_depth = np.array([103.46727, 85.812004, 87.849236])

assert np.abs(image_slice_rgb.flatten() - expected_slice_rgb).max() < 1e-2
assert np.abs(image_slice_depth.flatten() - expected_slice_depth).max() < 1e-2
Expand Down Expand Up @@ -280,10 +280,30 @@ def test_ldm3d(self):
output = ldm3d_pipe(**inputs)
rgb, depth = output.rgb, output.depth

expected_rgb_mean = 0.54461557
expected_rgb_std = 0.2806707
expected_depth_mean = 143.64595
expected_depth_std = 83.491776
expected_rgb_mean = 0.495586
expected_rgb_std = 0.33795515
expected_depth_mean = 112.48518
expected_depth_std = 98.489746
assert np.abs(expected_rgb_mean - rgb.mean()) < 1e-3
assert np.abs(expected_rgb_std - rgb.std()) < 1e-3
assert np.abs(expected_depth_mean - depth.mean()) < 1e-3
assert np.abs(expected_depth_std - depth.std()) < 1e-3

def test_ldm3d_v2(self):
ldm3d_pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c").to(torch_device)
ldm3d_pipe.set_progress_bar_config(disable=None)

inputs = self.get_inputs(torch_device)
output = ldm3d_pipe(**inputs)
rgb, depth = output.rgb, output.depth

expected_rgb_mean = 0.4194127
expected_rgb_std = 0.35375586
expected_depth_mean = 0.5638502
expected_depth_std = 0.34686103

assert rgb.shape == (1, 512, 512, 3)
assert depth.shape == (1, 512, 512, 1)
assert np.abs(expected_rgb_mean - rgb.mean()) < 1e-3
assert np.abs(expected_rgb_std - rgb.std()) < 1e-3
assert np.abs(expected_depth_mean - depth.mean()) < 1e-3
Expand Down