Skip to content

Commit 514a7f9

Browse files
tests: add test for get_flux_in_channels_from_state_dict()
1 parent bf04c87 commit 514a7f9

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from pathlib import Path
2+
3+
import pytest
4+
5+
from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict
6+
from invokeai.backend.model_manager.config import ModelOnDisk
7+
8+
test_cases = [
9+
# Unquantized
10+
("FLUX Dev.safetensors", 64),
11+
("FLUX Schnell.safetensors", 64),
12+
("FLUX Fill.safetensors", 384),
13+
# BNB-NF4 quantized
14+
("FLUX Dev (Quantized).safetensors", 1), # BNB-NF4
15+
("FLUX Schnell (Quantized).safetensors", 1), # BNB-NF4
16+
# GGUF quantized FLUX Fill
17+
("flux1-fill-dev-Q8_0.gguf", 384),
18+
# Fine-tune w/ "model.diffusion_model.img_in.weight" instead of "img_in.weight"
19+
("midjourneyReplica_flux1Dev.safetensors", 64),
20+
# Not a FLUX model, testing fallback case
21+
("Noodles Style.safetensors", None),
22+
]
23+
24+
25+
@pytest.mark.parametrize("model_file_name,expected_in_channels", test_cases)
26+
def test_get_flux_in_channels_from_state_dict(model_file_name: str, expected_in_channels: int, override_model_loading):
27+
model_path = Path(f"tests/test_model_probe/stripped_models/{model_file_name}")
28+
29+
mod = ModelOnDisk(model_path)
30+
31+
state_dict = mod.load_state_dict()
32+
33+
in_channels = get_flux_in_channels_from_state_dict(state_dict)
34+
35+
assert in_channels == expected_in_channels

0 commit comments

Comments
 (0)