Skip to content

Commit c3765a9

Browse files
hlkysayakpaul
authored andcommitted
Add ControlNetUnion to AutoPipeline from_pretrained (#10219)
1 parent c09047e commit c3765a9

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

src/diffusers/pipelines/auto_pipeline.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from huggingface_hub.utils import validate_hf_hub_args
1919

2020
from ..configuration_utils import ConfigMixin
21+
from ..models.controlnets import ControlNetUnionModel
2122
from ..utils import is_sentencepiece_available
2223
from .aura_flow import AuraFlowPipeline
2324
from .cogview3 import CogView3PlusPipeline
@@ -28,6 +29,9 @@
2829
StableDiffusionXLControlNetImg2ImgPipeline,
2930
StableDiffusionXLControlNetInpaintPipeline,
3031
StableDiffusionXLControlNetPipeline,
32+
StableDiffusionXLControlNetUnionImg2ImgPipeline,
33+
StableDiffusionXLControlNetUnionInpaintPipeline,
34+
StableDiffusionXLControlNetUnionPipeline,
3135
)
3236
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
3337
from .flux import (
@@ -108,6 +112,7 @@
108112
("kandinsky3", Kandinsky3Pipeline),
109113
("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
110114
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
115+
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline),
111116
("wuerstchen", WuerstchenCombinedPipeline),
112117
("cascade", StableCascadeCombinedPipeline),
113118
("lcm", LatentConsistencyModelPipeline),
@@ -139,6 +144,7 @@
139144
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
140145
("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
141146
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
147+
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline),
142148
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
143149
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
144150
("lcm", LatentConsistencyModelImg2ImgPipeline),
@@ -158,6 +164,7 @@
158164
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
159165
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline),
160166
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
167+
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionInpaintPipeline),
161168
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
162169
("flux", FluxInpaintPipeline),
163170
("flux-controlnet", FluxControlNetInpaintPipeline),
@@ -396,7 +403,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
396403
orig_class_name = config["_class_name"]
397404

398405
if "controlnet" in kwargs:
399-
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
406+
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
407+
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline")
408+
else:
409+
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
400410
if "enable_pag" in kwargs:
401411
enable_pag = kwargs.pop("enable_pag")
402412
if enable_pag:
@@ -688,7 +698,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
688698
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
689699

690700
if "controlnet" in kwargs:
691-
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
701+
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
702+
orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
703+
else:
704+
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
692705
if "enable_pag" in kwargs:
693706
enable_pag = kwargs.pop("enable_pag")
694707
if enable_pag:
@@ -985,7 +998,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
985998
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
986999

9871000
if "controlnet" in kwargs:
988-
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
1001+
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
1002+
orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
1003+
else:
1004+
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
9891005
if "enable_pag" in kwargs:
9901006
enable_pag = kwargs.pop("enable_pag")
9911007
if enable_pag:

0 commit comments

Comments
 (0)