From b4f4ee93995ddb4c438393556e1e7938de08f7c3 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 18 Dec 2024 14:35:14 +0000 Subject: [PATCH] Add Flux Control to AutoPipeline --- src/diffusers/pipelines/auto_pipeline.py | 37 ++++++++++++++++++++---- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index a0f95fe6cdc1..f3a05c2c661f 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -35,9 +35,12 @@ ) from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline from .flux import ( + FluxControlImg2ImgPipeline, + FluxControlInpaintPipeline, FluxControlNetImg2ImgPipeline, FluxControlNetInpaintPipeline, FluxControlNetPipeline, + FluxControlPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline, @@ -125,6 +128,7 @@ ("pixart-sigma-pag", PixArtSigmaPAGPipeline), ("auraflow", AuraFlowPipeline), ("flux", FluxPipeline), + ("flux-control", FluxControlPipeline), ("flux-controlnet", FluxControlNetPipeline), ("lumina", LuminaText2ImgPipeline), ("cogview3", CogView3PlusPipeline), @@ -150,6 +154,7 @@ ("lcm", LatentConsistencyModelImg2ImgPipeline), ("flux", FluxImg2ImgPipeline), ("flux-controlnet", FluxControlNetImg2ImgPipeline), + ("flux-control", FluxControlImg2ImgPipeline), ] ) @@ -168,6 +173,7 @@ ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), ("flux", FluxInpaintPipeline), ("flux-controlnet", FluxControlNetInpaintPipeline), + ("flux-control", FluxControlInpaintPipeline), ("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline), ] ) @@ -401,16 +407,20 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) orig_class_name = config["_class_name"] + if "ControlPipeline" in orig_class_name: + to_replace = "ControlPipeline" + else: + to_replace = "Pipeline" if "controlnet" in kwargs: if isinstance(kwargs["controlnet"], ControlNetUnionModel): - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline") + orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline") else: - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline") if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: - orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline") + orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline") text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name) @@ -694,8 +704,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): # the `orig_class_name` can be: # `- *Pipeline` (for regular text-to-image checkpoint) + # - `*ControlPipeline` (for Flux tools specific checkpoint) # `- *Img2ImgPipeline` (for refiner checkpoint) - to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline" + if "Img2Img" in orig_class_name: + to_replace = "Img2ImgPipeline" + elif "ControlPipeline" in orig_class_name: + to_replace = "ControlPipeline" + else: + to_replace = "Pipeline" if "controlnet" in kwargs: if isinstance(kwargs["controlnet"], ControlNetUnionModel): @@ -707,6 +723,9 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): if enable_pag: orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) + if to_replace == "ControlPipeline": + orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline") + image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name) kwargs = {**load_config_kwargs, **kwargs} @@ -994,8 +1013,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): # The `orig_class_name`` can be: # `- *InpaintPipeline` (for inpaint-specific checkpoint) + # - `*ControlPipeline` (for Flux tools specific checkpoint) # - or *Pipeline (for regular text-to-image checkpoint) - to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" + if "Inpaint" in orig_class_name: + to_replace = "InpaintPipeline" + elif "ControlPipeline" in orig_class_name: + to_replace = "ControlPipeline" + else: + to_replace = "Pipeline" if "controlnet" in kwargs: if isinstance(kwargs["controlnet"], ControlNetUnionModel): @@ -1006,6 +1031,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): enable_pag = kwargs.pop("enable_pag") if enable_pag: orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace) + if to_replace == "ControlPipeline": + orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline") inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) kwargs = {**load_config_kwargs, **kwargs}