Commit feed44b
# Stripped models (#7797)
## Summary

**Problem**

We want to have automated tests for model classification/probing, but model files are too large to include in the source.

**Proposed Solution**

Classification/probing only requires metadata (key names, tensor shapes), not weights. This PR introduces "stripped" models - lightweight versions that retain only the essential metadata. Each tensor in a state_dict is replaced by a small record of its shape and dtype (see the sketch after the checklist below).

- Added a script to strip models
- Added stripped models to the automated tests

**Model size before and after stripping:**

| Model | Before | After |
| --- | --- | --- |
| LLaVA Onevision Qwen2 0.5b-ov-hf | 1.8 GB | 11.6 MB |
| text_encoder | 246.1 MB | 35.6 kB |
| llava-onevision-qwen2-7b-si-hf | 16.1 GB | 11.7 MB |
| RealESRGAN_x2plus.pth | 67.1 MB | 143.0 kB |
| IP Adapter SD1 | 2.5 GB | 94.9 kB |
| Hard Edge Detection (canny) | 722.6 MB | 63.6 kB |
| Lineart | 722.6 MB | 63.6 kB |
| Segmentation Map | 722.6 MB | 63.6 kB |
| EasyNegative | 24.7 kB | 151 Bytes |
| Face Reference (IP Adapter Plus Face) | 98.2 MB | 13.7 kB |
| Standard Reference (IP Adapter) | 44.6 MB | 6.0 kB |
| shinkai_makoto_offset | 151.1 MB | 160.0 kB |
| thickline_fp16 | 151.1 MB | 160.0 kB |
| Alien Style | 228.5 MB | 582.6 kB |
| Noodles Style | 228.5 MB | 582.6 kB |
| Juggernaut XL v9 | 6.9 GB | 3.7 MB |
| dreamshaper-8 | 168.9 MB | 1.6 MB |

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or discord. If this PR closes an issue, please use the "Closes #1234" format, so that the issue will be automatically closed when the PR merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like DB schemas, may need some care when merging. For example, a careful rebase by the change author, timing to not interfere with a pending release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
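To make the idea concrete, here is a minimal sketch of the stripping round trip, simplified from `scripts/strip_models.py` in this diff (the example key and shape are made up; the real script walks nested dicts/lists and uses a `STR_TO_DTYPE` lookup):

```python
import torch

# A real state_dict entry...
sd = {"model.diffusion_model.out.weight": torch.zeros(320, 4, 3, 3)}

# ...is "stripped" to a JSON-serializable record of shape and dtype only:
stripped = {
    k: {"shape": list(v.shape), "dtype": str(v.dtype), "fakeTensor": True}
    for k, v in sd.items()
}
# -> {"model.diffusion_model.out.weight":
#      {"shape": [320, 4, 3, 3], "dtype": "torch.float32", "fakeTensor": True}}

# At test time the record is "dressed" back into an uninitialized tensor of
# the right shape and dtype, which is all that classification inspects:
dressed = {
    k: torch.empty(v["shape"], dtype=getattr(torch, v["dtype"].removeprefix("torch.")))
    for k, v in stripped.items()
}
```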
2 parents 7fe4d4c + 8e14f9d commit feed44b

9 files changed: +211 -21 lines changed

Diff for: invokeai/backend/model_manager/__init__.py

+2

```diff
@@ -5,6 +5,7 @@
     AnyModelConfig,
     BaseModelType,
     InvalidModelConfigException,
+    ModelConfigBase,
     ModelConfigFactory,
     ModelFormat,
     ModelRepoVariant,
@@ -32,4 +33,5 @@
     "ModelVariantType",
     "SchedulerPredictionType",
     "SubModelType",
+    "ModelConfigBase",
 ]
```

Diff for: invokeai/backend/model_manager/config.py

+33 -5

```diff
@@ -25,23 +25,26 @@
 import time
 from abc import ABC, abstractmethod
 from enum import Enum
-from functools import cached_property
 from inspect import isabstract
 from pathlib import Path
 from typing import ClassVar, Literal, Optional, TypeAlias, Union
 
 import diffusers
 import onnxruntime as ort
+import safetensors.torch
 import torch
 from diffusers.models.modeling_utils import ModelMixin
+from picklescan.scanner import scan_file_path
 from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
 from typing_extensions import Annotated, Any, Dict
 
 from invokeai.app.util.misc import uuid_string
 from invokeai.backend.model_hash.hash_validator import validate_hash
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
+from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
 from invokeai.backend.raw_model import RawModel
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
+from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 logger = logging.getLogger(__name__)
 
@@ -215,12 +218,37 @@ def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
         self.name = path.name
         self.hash_algo = hash_algo
 
-    @cached_property
     def hash(self):
         return ModelHash(algorithm=self.hash_algo).hash(self.path)
 
-    def lazy_load_state_dict(self) -> dict[str, torch.Tensor]:
-        raise NotImplementedError()
+    def size(self):
+        if self.format_type == ModelFormat.Checkpoint:
+            return self.path.stat().st_size
+        return sum(file.stat().st_size for file in self.path.rglob("*"))
+
+    def component_paths(self):
+        if self.format_type == ModelFormat.Checkpoint:
+            return {self.path}
+        extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
+        return {f for f in self.path.rglob("*") if f.suffix in extensions}
+
+    @staticmethod
+    def load_state_dict(path: Path):
+        with SilenceWarnings():
+            if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
+                scan_result = scan_file_path(path)
+                if scan_result.infected_files != 0 or scan_result.scan_err:
+                    raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
+                checkpoint = torch.load(path, map_location="cpu")
+            elif path.suffix.endswith(".gguf"):
+                checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
+            elif path.suffix.endswith(".safetensors"):
+                checkpoint = safetensors.torch.load_file(path)
+            else:
+                raise ValueError(f"Unrecognized model extension: {path.suffix}")
+
+        state_dict = checkpoint.get("state_dict", checkpoint)
+        return state_dict
 
 
 class MatchSpeed(int, Enum):
@@ -343,7 +371,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
         fields["source"] = fields.get("source") or fields["path"]
         fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
         fields["name"] = mod.name
-        fields["hash"] = fields.get("hash") or mod.hash
+        fields["hash"] = fields.get("hash") or mod.hash()
 
         fields.update(overrides)
         return cls(**fields)
```
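As a usage note, the new `ModelOnDisk` helpers added above could be exercised roughly like this (a sketch only; the checkpoint path is hypothetical):

```python
from pathlib import Path

from invokeai.backend.model_manager.config import ModelOnDisk

# Hypothetical single-file checkpoint for illustration.
mod = ModelOnDisk(Path("models/juggernaut-xl-v9.safetensors"))
print(mod.size())  # st_size of the file itself for checkpoint-format models

for component in mod.component_paths():
    # A .safetensors file takes the safetensors branch of load_state_dict.
    state_dict = ModelOnDisk.load_state_dict(component)
    print(component.name, len(state_dict), "keys")
```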

Diff for: invokeai/backend/model_manager/legacy_probe.py

+2 -2

```diff
@@ -3,10 +3,10 @@
 from pathlib import Path
 from typing import Any, Callable, Dict, Literal, Optional, Union
 
+import picklescan.scanner as pscan
 import safetensors.torch
 import spandrel
 import torch
-from picklescan.scanner import scan_file_path
 
 import invokeai.backend.util.logging as logger
 from invokeai.app.util.misc import uuid_string
@@ -483,7 +483,7 @@ def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
         and option to exit if an infected file is identified.
         """
         # scan model
-        scan_result = scan_file_path(checkpoint)
+        scan_result = pscan.scan_file_path(checkpoint)
         if scan_result.infected_files != 0:
             raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
         if scan_result.scan_err:
```

Diff for: invokeai/backend/model_manager/util/model_util.py

+2 -2

```diff
@@ -4,9 +4,9 @@
 from pathlib import Path
 from typing import Dict, Optional, Union
 
+import picklescan.scanner as pscan
 import safetensors
 import torch
-from picklescan.scanner import scan_file_path
 
 from invokeai.backend.model_manager.config import ClipVariantType
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
@@ -57,7 +57,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str,
         checkpoint = gguf_sd_loader(Path(path), compute_dtype=torch.float32)
     else:
         if scan:
-            scan_result = scan_file_path(path)
+            scan_result = pscan.scan_file_path(path)
             if scan_result.infected_files != 0:
                 raise Exception(f"The model at {path} is potentially infected by malware. Aborting import.")
             if scan_result.scan_err:
```

Diff for: pyproject.toml

+2

```diff
@@ -95,6 +95,7 @@ dependencies = [
     "semver~=3.0.1",
     "test-tube",
     "windows-curses; sys_platform=='win32'",
+    "humanize==4.12.1",
 ]
 
 [project.optional-dependencies]
@@ -103,6 +104,7 @@ dependencies = [
     "xformers>=0.0.28.post1; sys_platform!='darwin'",
     # torch 2.4+cu carries its own triton dependency
 ]
+
 "onnx" = ["onnxruntime"]
 "onnx-cuda" = ["onnxruntime-gpu"]
 "onnx-directml" = ["onnxruntime-directml"]
```

Diff for: scripts/probe-model.py renamed to scripts/classify-model.py

+13 -5

```diff
@@ -7,7 +7,7 @@
 from typing import get_args
 
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
-from invokeai.backend.model_manager import InvalidModelConfigException, ModelProbe
+from invokeai.backend.model_manager import InvalidModelConfigException, ModelConfigBase, ModelProbe
 
 algos = ", ".join(set(get_args(HASHING_ALGORITHMS)))
 
@@ -25,9 +25,17 @@
 )
 args = parser.parse_args()
 
+
+def classify_with_fallback(path: Path, hash_algo: HASHING_ALGORITHMS):
+    try:
+        return ModelConfigBase.classify(path, hash_algo)
+    except InvalidModelConfigException:
+        return ModelProbe.probe(path, hash_algo=hash_algo)
+
+
 for path in args.model_path:
     try:
-        info = ModelProbe.probe(path, hash_algo=args.hash_algo)
-        print(f"{path}:{info.model_dump_json(indent=4)}")
-    except InvalidModelConfigException as exc:
-        print(exc)
+        config = classify_with_fallback(path, args.hash_algo)
+        print(f"{path}:{config.model_dump_json(indent=4)}")
+    except InvalidModelConfigException as e:
+        print(e)
```

Diff for: scripts/strip_models.py

+115

```diff
@@ -0,0 +1,115 @@
+"""
+Usage:
+    strip_models.py <models_input_dir> <stripped_output_dir>
+
+Strips tensor data from model state_dicts while preserving metadata.
+Used to create lightweight models for testing model classification.
+
+Parameters:
+    <models_input_dir>       Directory containing original models.
+    <stripped_output_dir>    Directory where stripped models will be saved.
+
+Options:
+    -h, --help  Show this help message and exit
+"""
+
+import argparse
+import json
+import shutil
+import sys
+from pathlib import Path
+
+import humanize
+import torch
+
+from invokeai.backend.model_manager.config import ModelFormat, ModelOnDisk
+from invokeai.backend.model_manager.search import ModelSearch
+
+
+def strip(v):
+    match v:
+        case torch.Tensor():
+            return {"shape": v.shape, "dtype": str(v.dtype), "fakeTensor": True}
+        case dict():
+            return {k: strip(v) for k, v in v.items()}
+        case list() | tuple():
+            return [strip(x) for x in v]
+        case _:
+            return v
+
+
+STR_TO_DTYPE = {str(dtype): dtype for dtype in torch.__dict__.values() if isinstance(dtype, torch.dtype)}
+
+
+def dress(v):
+    match v:
+        case {"shape": shape, "dtype": dtype_str, "fakeTensor": True}:
+            dtype = STR_TO_DTYPE[dtype_str]
+            return torch.empty(shape, dtype=dtype)
+        case dict():
+            return {k: dress(v) for k, v in v.items()}
+        case list() | tuple():
+            return [dress(x) for x in v]
+        case _:
+            return v
+
+
+def load_stripped_model(path: Path, *args, **kwargs):
+    with open(path, "r") as f:
+        contents = json.load(f)
+    return dress(contents)
+
+
+def create_stripped_model(original_model_path: Path, stripped_model_path: Path) -> ModelOnDisk:
+    original = ModelOnDisk(original_model_path)
+    if original.format_type == ModelFormat.Checkpoint:
+        shutil.copy2(original.path, stripped_model_path)
+    else:
+        shutil.copytree(original.path, stripped_model_path, dirs_exist_ok=True)
+    stripped = ModelOnDisk(stripped_model_path)
+    print(f"Created clone of {original.name} at {stripped.path}")
+
+    for component_path in stripped.component_paths():
+        original_state_dict = ModelOnDisk.load_state_dict(component_path)
+        stripped_state_dict = strip(original_state_dict)  # type: ignore
+        with open(component_path, "w") as f:
+            json.dump(stripped_state_dict, f, indent=4)
+
+    before_size = humanize.naturalsize(original.size())
+    after_size = humanize.naturalsize(stripped.size())
+    print(f"{original.name} before: {before_size}, after: {after_size}")
+
+    return stripped
+
+
+def parse_arguments():
+    class Parser(argparse.ArgumentParser):
+        def error(self, reason):
+            raise ValueError(reason)
+
+    parser = Parser()
+    parser.add_argument("models_input_dir", type=Path)
+    parser.add_argument("stripped_output_dir", type=Path)
+
+    try:
+        args = parser.parse_args()
+    except ValueError as e:
+        print(f"Error: {e}", file=sys.stderr)
+        print(__doc__, file=sys.stderr)
+        sys.exit(2)
+
+    if not args.models_input_dir.exists():
+        parser.error(f"Error: Input models directory '{args.models_input_dir}' does not exist.")
+    if not args.models_input_dir.is_dir():
+        parser.error(f"Error: '{args.models_input_dir}' is not a directory.")
+
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    model_paths = sorted(ModelSearch().search(args.models_input_dir))
+
+    for path in model_paths:
+        stripped_path = args.stripped_output_dir / path.name
+        create_stripped_model(path, stripped_path)
```
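Besides the CLI described in the docstring, the two key entry points can also be used programmatically; a sketch under the assumption that the paths below exist:

```python
from pathlib import Path

from scripts.strip_models import create_stripped_model, load_stripped_model

# Hypothetical paths for illustration.
stripped = create_stripped_model(
    Path("models/dreamshaper-8.safetensors"),
    Path("stripped/dreamshaper-8.safetensors"),
)

# On disk the stripped file is JSON, but it loads back as a state_dict of
# empty tensors with the original shapes and dtypes.
state_dict = load_stripped_model(Path("stripped/dreamshaper-8.safetensors"))
```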

Diff for: tests/conftest.py

+26

```diff
@@ -7,9 +7,14 @@
 import logging
 import shutil
 from pathlib import Path
+from types import SimpleNamespace
 
+import picklescan.scanner
 import pytest
+import safetensors.torch
+import torch
 
+import invokeai.backend.quantization.gguf.loaders as gguf_loaders
 from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
 from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
 from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
@@ -20,6 +25,7 @@
 from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.util.logging import InvokeAILogger
+from scripts.strip_models import load_stripped_model
 from tests.backend.model_manager.model_manager_fixtures import *  # noqa: F403
 from tests.fixtures.sqlite_database import create_mock_sqlite_database  # noqa: F401
 from tests.test_nodes import TestEventService
@@ -73,3 +79,23 @@ def invokeai_root_dir(tmp_path_factory) -> Path:
     temp_dir: Path = tmp_path_factory.mktemp("data") / "invokeai_root"
     shutil.copytree(root_template, temp_dir)
     return temp_dir
+
+
+@pytest.fixture(scope="function")
+def override_model_loading(monkeypatch):
+    """The legacy model probe directly calls model loading functions (e.g. torch.load) and also performs file scanning
+    via picklescan.scanner.scan_file_path. This fixture replaces these functions with test-friendly versions for
+    model files that have been 'stripped' to reduce their size (see scripts/strip_models.py).
+
+    Ideally, model loading would be injected as a dependency (i.e. ModelOnDisk) - but to avoid modifying the legacy
+    probe, we monkeypatch as a temporary workaround until the legacy probe is fully deprecated.
+    """
+    monkeypatch.setattr(torch, "load", load_stripped_model)
+    monkeypatch.setattr(safetensors.torch, "load", load_stripped_model)
+    monkeypatch.setattr(safetensors.torch, "load_file", load_stripped_model)
+    monkeypatch.setattr(gguf_loaders, "gguf_sd_loader", load_stripped_model)
+
+    def fake_scan(*args, **kwargs):
+        return SimpleNamespace(infected_files=0, scan_err=None)
+
+    monkeypatch.setattr(picklescan.scanner, "scan_file_path", fake_scan)
```
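A hypothetical test using this fixture might look like the following (the asset path and test are illustrative assumptions, not part of this PR):

```python
from pathlib import Path

from invokeai.backend.model_manager import ModelProbe


def test_classify_stripped_checkpoint(override_model_loading):
    # With the fixture active, torch.load / safetensors loading and the
    # picklescan check are patched, so the legacy probe can read a stripped
    # (JSON) model file as if it were a real checkpoint.
    config = ModelProbe.probe(Path("tests/assets/stripped/dreamshaper-8.safetensors"))
    assert config is not None
```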
