diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py index 8a0e770d037..dd94ee0e900 100644 --- a/invokeai/backend/model_manager/legacy_probe.py +++ b/invokeai/backend/model_manager/legacy_probe.py @@ -572,6 +572,8 @@ def get_variant_type(self) -> ModelVariantType: if in_channels is None: # If we cannot find the in_channels, we assume that this is a normal variant. Log a warning. + # If this occurs, we should add a test case for the affected model here: + # tests/backend/flux/test_flux_state_dict_utils.py logger.warning( f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant." ) diff --git a/tests/backend/flux/test_flux_state_dict_utils.py b/tests/backend/flux/test_flux_state_dict_utils.py new file mode 100644 index 00000000000..c4540ef0d22 --- /dev/null +++ b/tests/backend/flux/test_flux_state_dict_utils.py @@ -0,0 +1,35 @@ +from pathlib import Path + +import pytest + +from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict +from invokeai.backend.model_manager.config import ModelOnDisk + +test_cases = [ + # Unquantized + ("FLUX Dev.safetensors", 64), + ("FLUX Schnell.safetensors", 64), + ("FLUX Fill.safetensors", 384), + # BNB-NF4 quantized + ("FLUX Dev (Quantized).safetensors", 1), # BNB-NF4 + ("FLUX Schnell (Quantized).safetensors", 1), # BNB-NF4 + # GGUF quantized FLUX Fill + ("flux1-fill-dev-Q8_0.gguf", 384), + # Fine-tune w/ "model.diffusion_model.img_in.weight" instead of "img_in.weight" + ("midjourneyReplica_flux1Dev.safetensors", 64), + # Not a FLUX model, testing fallback case + ("Noodles Style.safetensors", None), +] + + +@pytest.mark.parametrize("model_file_name,expected_in_channels", test_cases) +def test_get_flux_in_channels_from_state_dict(model_file_name: str, expected_in_channels: int, override_model_loading): + model_path = Path(f"tests/test_model_probe/stripped_models/{model_file_name}") + + mod = ModelOnDisk(model_path) + + state_dict = mod.load_state_dict() + + in_channels = get_flux_in_channels_from_state_dict(state_dict) + + assert in_channels == expected_in_channels diff --git a/tests/conftest.py b/tests/conftest.py index b112a4ff2e9..aa21aa88621 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,6 +94,9 @@ def override_model_loading(monkeypatch): monkeypatch.setattr(safetensors.torch, "load", load_stripped_model) monkeypatch.setattr(safetensors.torch, "load_file", load_stripped_model) monkeypatch.setattr(gguf_loaders, "gguf_sd_loader", load_stripped_model) + monkeypatch.setattr("invokeai.backend.model_manager.config.gguf_sd_loader", load_stripped_model) + monkeypatch.setattr("invokeai.backend.model_manager.util.model_util.gguf_sd_loader", load_stripped_model) + monkeypatch.setattr("invokeai.backend.model_manager.legacy_probe.gguf_sd_loader", load_stripped_model) def fake_scan(*args, **kwargs): return SimpleNamespace(infected_files=0, scan_err=None) diff --git a/tests/test_model_probe/stripped_models/FLUX Dev (Quantized).safetensors b/tests/test_model_probe/stripped_models/FLUX Dev (Quantized).safetensors new file mode 100644 index 00000000000..e8646cb043a --- /dev/null +++ b/tests/test_model_probe/stripped_models/FLUX Dev (Quantized).safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe25212279fec351340d1c4a9da0eb902af82162350970c148bf331c1c02f3c5 +size 292730 diff --git a/tests/test_model_probe/stripped_models/FLUX Dev.safetensors b/tests/test_model_probe/stripped_models/FLUX Dev.safetensors new file mode 100644 index 00000000000..32718fac409 --- /dev/null +++ b/tests/test_model_probe/stripped_models/FLUX Dev.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84850676ab6fc163b4fe3bb87b1584a5a78b523e5f6e58b6ecb2c7d34e4c0796 +size 130743 diff --git a/tests/test_model_probe/stripped_models/FLUX Fill.safetensors b/tests/test_model_probe/stripped_models/FLUX Fill.safetensors new file mode 100644 index 00000000000..d15e6cb0e0b --- /dev/null +++ b/tests/test_model_probe/stripped_models/FLUX Fill.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb64744f32674cd1e8c3c09e578d18e1ca84c3deac0ef0a2fc3654ec9ac0a84d +size 130744 diff --git a/tests/test_model_probe/stripped_models/FLUX Schnell (Quantized).safetensors b/tests/test_model_probe/stripped_models/FLUX Schnell (Quantized).safetensors new file mode 100644 index 00000000000..30688e9c339 --- /dev/null +++ b/tests/test_model_probe/stripped_models/FLUX Schnell (Quantized).safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:42cd75dbd5dec6252de6f959a6ed678fb0e5bef166eca7ac38c51577a0d4e4eb +size 291091 diff --git a/tests/test_model_probe/stripped_models/FLUX Schnell.safetensors b/tests/test_model_probe/stripped_models/FLUX Schnell.safetensors new file mode 100644 index 00000000000..4f46a9fe198 --- /dev/null +++ b/tests/test_model_probe/stripped_models/FLUX Schnell.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1533dced878ca5a8bae39bfdbed85dfd97e937ec3c97540da1e7d4011ffed98 +size 130098 diff --git a/tests/test_model_probe/stripped_models/flux1-fill-dev-Q8_0.gguf b/tests/test_model_probe/stripped_models/flux1-fill-dev-Q8_0.gguf new file mode 100644 index 00000000000..deabac76c9e --- /dev/null +++ b/tests/test_model_probe/stripped_models/flux1-fill-dev-Q8_0.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cac069dd904e0d676baacecfeaba52bbbe808a6d755dabdd94c7281656fa0507 +size 129356 diff --git a/tests/test_model_probe/stripped_models/midjourneyReplica_flux1Dev.safetensors b/tests/test_model_probe/stripped_models/midjourneyReplica_flux1Dev.safetensors new file mode 100644 index 00000000000..9fd14405496 --- /dev/null +++ b/tests/test_model_probe/stripped_models/midjourneyReplica_flux1Dev.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98d0f54489ec096f543a9b8f88683fd960acd96521d987e027be9e23d621d96f +size 151803