From 20469a82d32368e6856865031142cfce8376d859 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Mon, 31 Mar 2025 09:10:45 +1000
Subject: [PATCH] fix(mm): handle FLUX models w/ diff in_channels keys

Before FLUX Fill was merged, we didn't do any checks for the model variant. We always returned "normal".

To determine if a model is a FLUX Fill model, we need to check the state dict for a specific key. Initially, this logic was too strict and rejected quantized FLUX models. This issue was resolved, but it turns out there is another failure mode - some fine-tunes use a different key.

This change further reduces the strictness, handling the alternate key and also falling back to "normal" if we don't see either key. This effectively restores the previous probing behaviour for all FLUX models.

Closes #7856
Closes #7859
---
 .../backend/flux/flux_state_dict_utils.py     | 23 +++++++++++++++++++
 .../backend/model_manager/legacy_probe.py     | 10 +++++++-
 2 files changed, 32 insertions(+), 1 deletion(-)
 create mode 100644 invokeai/backend/flux/flux_state_dict_utils.py

diff --git a/invokeai/backend/flux/flux_state_dict_utils.py b/invokeai/backend/flux/flux_state_dict_utils.py
new file mode 100644
index 00000000000..8ffab54c688
--- /dev/null
+++ b/invokeai/backend/flux/flux_state_dict_utils.py
@@ -0,0 +1,23 @@
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from invokeai.backend.model_manager.legacy_probe import CkptType
+
+
+def get_flux_in_channels_from_state_dict(state_dict: "CkptType") -> int | None:
+    """Gets the in channels (shape[1] of the img_in weight) from the state dict, or None if no known key is found."""
+
+    # "Standard" FLUX models use "img_in.weight", but some community fine tunes use
+    # "model.diffusion_model.img_in.weight". Known models that use the latter key:
+    # - https://civitai.com/models/885098?modelVersionId=990775
+    # - https://civitai.com/models/1018060?modelVersionId=1596255
+    # - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133
+    # Keys are probed in this order (a tuple, not a set, so the standard key deterministically wins if both exist).
+    keys = ("img_in.weight", "model.diffusion_model.img_in.weight")
+
+    for key in keys:
+        val = state_dict.get(key)
+        if val is not None:
+            return val.shape[1]
+
+    return None
diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py
index 304fbce346a..24a5a9f5277 100644
--- a/invokeai/backend/model_manager/legacy_probe.py
+++ b/invokeai/backend/model_manager/legacy_probe.py
@@ -14,6 +14,7 @@
     is_state_dict_instantx_controlnet,
     is_state_dict_xlabs_controlnet,
 )
+from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict
 from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
 from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
@@ -564,7 +565,14 @@ def get_variant_type(self) -> ModelVariantType:
         state_dict = self.checkpoint.get("state_dict") or self.checkpoint
 
         if base_type == BaseModelType.Flux:
-            in_channels = state_dict["img_in.weight"].shape[1]
+            in_channels = get_flux_in_channels_from_state_dict(state_dict)
+
+            if in_channels is None:
+                # If we cannot find the in_channels, we assume that this is a normal variant. Log a warning.
+                logger.warning(
+                    f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant."
+                )
+                return ModelVariantType.Normal
 
             # FLUX Model variant types are distinguished by input channels:
             # - Unquantized Dev and Schnell have in_channels=64