
Commit 8f2253c

hlky, sayakpaul, and yiyixuxu authored
Add torch_xla and from_single_file to instruct-pix2pix (#10444)
* Add torch_xla and from_single_file to instruct-pix2pix

* StableDiffusionInstructPix2PixPipelineSingleFileSlowTests

* StableDiffusionInstructPix2PixPipelineSingleFileSlowTests

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
1 parent 7747b58 · commit 8f2253c

File tree

3 files changed (+65, -3)

Diff for: src/diffusers/loaders/single_file_utils.py (+8)
@@ -109,6 +109,7 @@
     "autoencoder-dc-sana": "encoder.project_in.conv.bias",
     "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
+    "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -165,6 +166,7 @@
     "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
     "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
+    "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
 }
 
 # Use to configure model sample size when original config is provided
@@ -633,6 +635,12 @@ def infer_diffusers_model_type(checkpoint):
     elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
         model_type = "hunyuan-video"
 
+    elif (
+        CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
+    ):
+        model_type = "instruct-pix2pix"
+
     else:
         model_type = "v1"
Diff for: src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py (+13, -2)
@@ -22,16 +22,23 @@
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -79,6 +86,7 @@ class StableDiffusionInstructPix2PixPipeline(
     TextualInversionLoaderMixin,
     StableDiffusionLoraLoaderMixin,
     IPAdapterMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).
@@ -457,6 +465,9 @@ def __call__(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
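With FromSingleFileMixin in the class bases and the checkpoint format auto-detected, the original single-file weights can be loaded directly. A minimal usage sketch of what this enables (not part of the commit; shown on CUDA as an assumption, where the new XLA hook is a no-op — on an XLA device the pipeline now calls xm.mark_step() after each denoising step instead):

import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

# Load the pipeline straight from the original checkpoint file on the Hub.
pipe = StableDiffusionInstructPix2PixPipeline.from_single_file(
    "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors",
    torch_dtype=torch.float16,
).to("cuda")

image = load_image(
    "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_pix2pix/example.jpg"
)

edited = pipe(
    "turn him into a cyborg",
    image=image,
    num_inference_steps=20,
    guidance_scale=7.5,
    image_guidance_scale=1.0,
).images[0]
edited.save("cyborg.png")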

Diff for: tests/single_file/test_stable_diffusion_single_file.py (+44, -1)
@@ -4,11 +4,13 @@
 
 import torch
 
-from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
+from diffusers import EulerDiscreteScheduler, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline
 from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
     enable_full_determinism,
+    nightly,
     require_torch_accelerator,
     slow,
     torch_device,
@@ -118,3 +120,44 @@ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0
 
     def test_single_file_format_inference_is_same_as_pretrained(self):
         super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
+
+
+@nightly
+@slow
+@require_torch_accelerator
+class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+    pipeline_class = StableDiffusionInstructPix2PixPipeline
+    ckpt_path = "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors"
+    original_config = (
+        "https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/refs/heads/main/configs/generate.yaml"
+    )
+    repo_id = "timbrooks/instruct-pix2pix"
+
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=generator_device).manual_seed(seed)
+        image = load_image(
+            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_pix2pix/example.jpg"
+        )
+        inputs = {
+            "prompt": "turn him into a cyborg",
+            "image": image,
+            "generator": generator,
+            "num_inference_steps": 3,
+            "guidance_scale": 7.5,
+            "image_guidance_scale": 1.0,
+            "output_type": "np",
+        }
+        return inputs
+
+    def test_single_file_format_inference_is_same_as_pretrained(self):
+        super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
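The new test class reuses the equivalence check inherited from SDSingleFileTesterMixin. A rough standalone sketch of what that inherited test asserts — this is not the mixin's actual implementation, and the CUDA device and run() helper are assumptions for illustration — is that the same seeded inputs through a single-file load and a hub load must produce images agreeing within expected_max_diff:

import numpy as np
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

ckpt = "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors"
image = load_image(
    "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_pix2pix/example.jpg"
)

def run(pipe):
    # Same seeded generator and inputs as get_inputs above.
    pipe = pipe.to("cuda")
    generator = torch.Generator(device="cpu").manual_seed(0)
    return pipe(
        "turn him into a cyborg",
        image=image,
        generator=generator,
        num_inference_steps=3,
        guidance_scale=7.5,
        image_guidance_scale=1.0,
        output_type="np",
    ).images[0]

out_single = run(StableDiffusionInstructPix2PixPipeline.from_single_file(ckpt))
out_hub = run(StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix"))

# Pixel values are in [0, 1]; the two loads must agree to within 1e-3.
assert np.abs(out_single - out_hub).max() < 1e-3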
