diff --git a/invokeai/backend/flux/flux_state_dict_utils.py b/invokeai/backend/flux/flux_state_dict_utils.py new file mode 100644 index 00000000000..8ffab54c688 --- /dev/null +++ b/invokeai/backend/flux/flux_state_dict_utils.py @@ -0,0 +1,23 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from invokeai.backend.model_manager.legacy_probe import CkptType + + +def get_flux_in_channels_from_state_dict(state_dict: "CkptType") -> int | None: + """Gets the in channels from the state dict.""" + + # "Standard" FLUX models use "img_in.weight", but some community fine tunes use + # "model.diffusion_model.img_in.weight". Known models that use the latter key: + # - https://civitai.com/models/885098?modelVersionId=990775 + # - https://civitai.com/models/1018060?modelVersionId=1596255 + # - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133 + + keys = {"img_in.weight", "model.diffusion_model.img_in.weight"} + + for key in keys: + val = state_dict.get(key) + if val is not None: + return val.shape[1] + + return None diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py index 304fbce346a..24a5a9f5277 100644 --- a/invokeai/backend/model_manager/legacy_probe.py +++ b/invokeai/backend/model_manager/legacy_probe.py @@ -14,6 +14,7 @@ is_state_dict_instantx_controlnet, is_state_dict_xlabs_controlnet, ) +from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash @@ -564,7 +565,14 @@ def get_variant_type(self) -> ModelVariantType: state_dict = self.checkpoint.get("state_dict") or self.checkpoint if base_type == BaseModelType.Flux: - in_channels = state_dict["img_in.weight"].shape[1] + in_channels = get_flux_in_channels_from_state_dict(state_dict) + + if in_channels is None: + # If we cannot find the in_channels, we assume that this is a normal variant. Log a warning. + logger.warning( + f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant." + ) + return ModelVariantType.Normal # FLUX Model variant types are distinguished by input channels: # - Unquantized Dev and Schnell have in_channels=64