|
| 1 | +from pathlib import Path |
| 2 | + |
| 3 | +import pytest |
| 4 | + |
| 5 | +from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict |
| 6 | +from invokeai.backend.model_manager.config import ModelOnDisk |
| 7 | + |
| 8 | +test_cases = [ |
| 9 | + # Unquantized |
| 10 | + ("FLUX Dev.safetensors", 64), |
| 11 | + ("FLUX Schnell.safetensors", 64), |
| 12 | + ("FLUX Fill.safetensors", 384), |
| 13 | + # BNB-NF4 quantized |
| 14 | + ("FLUX Dev (Quantized).safetensors", 1), # BNB-NF4 |
| 15 | + ("FLUX Schnell (Quantized).safetensors", 1), # BNB-NF4 |
| 16 | + # GGUF quantized FLUX Fill |
| 17 | + ("flux1-fill-dev-Q8_0.gguf", 384), |
| 18 | + # Fine-tune w/ "model.diffusion_model.img_in.weight" instead of "img_in.weight" |
| 19 | + ("midjourneyReplica_flux1Dev.safetensors", 64), |
| 20 | + # Not a FLUX model, testing fallback case |
| 21 | + ("Noodles Style.safetensors", None), |
| 22 | +] |
| 23 | + |
| 24 | + |
| 25 | +@pytest.mark.parametrize("model_file_name,expected_in_channels", test_cases) |
| 26 | +def test_get_flux_in_channels_from_state_dict(model_file_name: str, expected_in_channels: int, override_model_loading): |
| 27 | + model_path = Path(f"tests/test_model_probe/stripped_models/{model_file_name}") |
| 28 | + |
| 29 | + mod = ModelOnDisk(model_path) |
| 30 | + |
| 31 | + state_dict = mod.load_state_dict() |
| 32 | + |
| 33 | + in_channels = get_flux_in_channels_from_state_dict(state_dict) |
| 34 | + |
| 35 | + assert in_channels == expected_in_channels |
0 commit comments