|
18 | 18 | from huggingface_hub.utils import validate_hf_hub_args
|
19 | 19 |
|
20 | 20 | from ..configuration_utils import ConfigMixin
|
| 21 | +from ..models.controlnets import ControlNetUnionModel |
21 | 22 | from ..utils import is_sentencepiece_available
|
22 | 23 | from .aura_flow import AuraFlowPipeline
|
23 | 24 | from .cogview3 import CogView3PlusPipeline
|
|
28 | 29 | StableDiffusionXLControlNetImg2ImgPipeline,
|
29 | 30 | StableDiffusionXLControlNetInpaintPipeline,
|
30 | 31 | StableDiffusionXLControlNetPipeline,
|
| 32 | + StableDiffusionXLControlNetUnionImg2ImgPipeline, |
| 33 | + StableDiffusionXLControlNetUnionInpaintPipeline, |
| 34 | + StableDiffusionXLControlNetUnionPipeline, |
31 | 35 | )
|
32 | 36 | from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
|
33 | 37 | from .flux import (
|
|
108 | 112 | ("kandinsky3", Kandinsky3Pipeline),
|
109 | 113 | ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
|
110 | 114 | ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
|
| 115 | + ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline), |
111 | 116 | ("wuerstchen", WuerstchenCombinedPipeline),
|
112 | 117 | ("cascade", StableCascadeCombinedPipeline),
|
113 | 118 | ("lcm", LatentConsistencyModelPipeline),
|
|
139 | 144 | ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
|
140 | 145 | ("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
|
141 | 146 | ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
|
| 147 | + ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline), |
142 | 148 | ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
|
143 | 149 | ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
|
144 | 150 | ("lcm", LatentConsistencyModelImg2ImgPipeline),
|
|
158 | 164 | ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
|
159 | 165 | ("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline),
|
160 | 166 | ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
|
| 167 | + ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionInpaintPipeline), |
161 | 168 | ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
|
162 | 169 | ("flux", FluxInpaintPipeline),
|
163 | 170 | ("flux-controlnet", FluxControlNetInpaintPipeline),
|
@@ -396,7 +403,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
|
396 | 403 | orig_class_name = config["_class_name"]
|
397 | 404 |
|
398 | 405 | if "controlnet" in kwargs:
|
399 |
| - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") |
| 406 | + if isinstance(kwargs["controlnet"], ControlNetUnionModel): |
| 407 | + orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline") |
| 408 | + else: |
| 409 | + orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") |
400 | 410 | if "enable_pag" in kwargs:
|
401 | 411 | enable_pag = kwargs.pop("enable_pag")
|
402 | 412 | if enable_pag:
|
@@ -688,7 +698,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
|
688 | 698 | to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
|
689 | 699 |
|
690 | 700 | if "controlnet" in kwargs:
|
691 |
| - orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) |
| 701 | + if isinstance(kwargs["controlnet"], ControlNetUnionModel): |
| 702 | + orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace) |
| 703 | + else: |
| 704 | + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) |
692 | 705 | if "enable_pag" in kwargs:
|
693 | 706 | enable_pag = kwargs.pop("enable_pag")
|
694 | 707 | if enable_pag:
|
@@ -985,7 +998,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
|
985 | 998 | to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
|
986 | 999 |
|
987 | 1000 | if "controlnet" in kwargs:
|
988 |
| - orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) |
| 1001 | + if isinstance(kwargs["controlnet"], ControlNetUnionModel): |
| 1002 | + orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace) |
| 1003 | + else: |
| 1004 | + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) |
989 | 1005 | if "enable_pag" in kwargs:
|
990 | 1006 | enable_pag = kwargs.pop("enable_pag")
|
991 | 1007 | if enable_pag:
|
|
0 commit comments