Skip to content

Commit a44bfb4

Browse files
fix(mm): handle FLUX models w/ diff in_channels keys
Before FLUX Fill was merged, we didn't do any checks for the model variant. We always returned "normal". To determine if a model is a FLUX Fill model, we need to check the state dict for a specific key. Initially, this logic was too strict and rejected quantized FLUX models. This issue was resolved, but it turns out there is another failure mode - some fine-tunes use a different key. This change further reduces the strictness, handling the alternate key and also falling back to "normal" if we don't see either key. This effectively restores the previous probing behaviour for all FLUX models. Closes #7856 Closes #7859
1 parent 96fb5f6 commit a44bfb4

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from typing import TYPE_CHECKING
2+
3+
if TYPE_CHECKING:
4+
from invokeai.backend.model_manager.legacy_probe import CkptType
5+
6+
7+
def get_flux_in_channels_from_state_dict(state_dict: "CkptType") -> int | None:
8+
"""Gets the in channels from the state dict."""
9+
10+
# "Standard" FLUX models use "img_in.weight", but some community fine tunes use
11+
# "model.diffusion_model.img_in.weight". Known models that use the latter key:
12+
# - https://civitai.com/models/885098?modelVersionId=990775
13+
# - https://civitai.com/models/1018060?modelVersionId=1596255
14+
# - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133
15+
16+
keys = {"img_in.weight", "model.diffusion_model.img_in.weight"}
17+
18+
for key in keys:
19+
val = state_dict.get(key)
20+
if val is not None:
21+
return val.shape[1]
22+
23+
return None

invokeai/backend/model_manager/legacy_probe.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
is_state_dict_instantx_controlnet,
1515
is_state_dict_xlabs_controlnet,
1616
)
17+
from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict
1718
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
1819
from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
1920
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
@@ -564,7 +565,14 @@ def get_variant_type(self) -> ModelVariantType:
564565
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
565566

566567
if base_type == BaseModelType.Flux:
567-
in_channels = state_dict["img_in.weight"].shape[1]
568+
in_channels = get_flux_in_channels_from_state_dict(state_dict)
569+
570+
if in_channels is None:
571+
# If we cannot find the in_channels, we assume that this is a normal variant. Log a warning.
572+
logger.warning(
573+
f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant."
574+
)
575+
return ModelVariantType.Normal
568576

569577
# FLUX Model variant types are distinguished by input channels:
570578
# - Unquantized Dev and Schnell have in_channels=64

0 commit comments

Comments
 (0)