Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tests: flux variant probe testing #7864

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions invokeai/backend/model_manager/legacy_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,8 @@ def get_variant_type(self) -> ModelVariantType:

if in_channels is None:
# If we cannot find the in_channels, we assume that this is a normal variant. Log a warning.
# If this occurs, we should add a test case for the affected model here:
# tests/backend/flux/test_flux_state_dict_utils.py
logger.warning(
f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant."
)
Expand Down
35 changes: 35 additions & 0 deletions tests/backend/flux/test_flux_state_dict_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from pathlib import Path

import pytest

from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict
from invokeai.backend.model_manager.config import ModelOnDisk

test_cases = [
# Unquantized
("FLUX Dev.safetensors", 64),
("FLUX Schnell.safetensors", 64),
("FLUX Fill.safetensors", 384),
# BNB-NF4 quantized
("FLUX Dev (Quantized).safetensors", 1), # BNB-NF4
("FLUX Schnell (Quantized).safetensors", 1), # BNB-NF4
# GGUF quantized FLUX Fill
("flux1-fill-dev-Q8_0.gguf", 384),
# Fine-tune w/ "model.diffusion_model.img_in.weight" instead of "img_in.weight"
("midjourneyReplica_flux1Dev.safetensors", 64),
# Not a FLUX model, testing fallback case
("Noodles Style.safetensors", None),
]


@pytest.mark.parametrize("model_file_name,expected_in_channels", test_cases)
def test_get_flux_in_channels_from_state_dict(model_file_name: str, expected_in_channels: int, override_model_loading):
model_path = Path(f"tests/test_model_probe/stripped_models/{model_file_name}")

mod = ModelOnDisk(model_path)

state_dict = mod.load_state_dict()

in_channels = get_flux_in_channels_from_state_dict(state_dict)

assert in_channels == expected_in_channels
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def override_model_loading(monkeypatch):
monkeypatch.setattr(safetensors.torch, "load", load_stripped_model)
monkeypatch.setattr(safetensors.torch, "load_file", load_stripped_model)
monkeypatch.setattr(gguf_loaders, "gguf_sd_loader", load_stripped_model)
monkeypatch.setattr("invokeai.backend.model_manager.config.gguf_sd_loader", load_stripped_model)
monkeypatch.setattr("invokeai.backend.model_manager.util.model_util.gguf_sd_loader", load_stripped_model)
monkeypatch.setattr("invokeai.backend.model_manager.legacy_probe.gguf_sd_loader", load_stripped_model)

def fake_scan(*args, **kwargs):
return SimpleNamespace(infected_files=0, scan_err=None)
Expand Down
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/test_model_probe/stripped_models/FLUX Dev.safetensors
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/test_model_probe/stripped_models/FLUX Fill.safetensors
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Loading