Commit a53e1cc

Small improvements (#7842)
## Summary

- Extend `ModelOnDisk` with caching, type hints, default args
- Fail early if there is an error classifying a config

## Related Issues / Discussions

## QA Instructions

## Merge Plan

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
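As a quick illustration of the first summary bullet, here is a minimal usage sketch of the extended `ModelOnDisk` API; the model path is hypothetical, the import path is taken from the file changed below, and the cache is keyed on the `path` passed to `load_state_dict`.

```python
from pathlib import Path

from invokeai.backend.model_manager.config import ModelOnDisk

# Hypothetical single-file checkpoint; substitute a real model file.
mod = ModelOnDisk(Path("models/example.safetensors"))

# `path` is now optional: with exactly one weight file it is resolved via
# component_paths(); zero or several candidates raise ValueError instead.
state_dict = mod.load_state_dict()

# Subsequent loads of the same resolved path are served from the
# per-instance cache, so the same dict object comes back.
assert mod.load_state_dict(mod.path) is state_dict
```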
2 parents c6f9661 + 1af9930 commit a53e1cc

2 files changed: +42 -19 lines

Diff for: invokeai/backend/model_manager/config.py

+41 -18

@@ -67,6 +67,11 @@ class InvalidModelConfigException(Exception):
 DEFAULTS_PRECISION = Literal["fp16", "fp32"]


+class FSLayout(Enum):
+    FILE = "file"
+    DIRECTORY = "directory"
+
+
 class SubmodelDefinition(BaseModel):
     path_or_prefix: str
     model_type: ModelType
@@ -102,29 +107,31 @@ class ModelOnDisk:

     def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
         self.path = path
-        self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
+        # TODO: Revisit checkpoint vs diffusers terminology
+        self.layout = FSLayout.DIRECTORY if path.is_dir() else FSLayout.FILE
         if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
             self.name = path.stem
         else:
             self.name = path.name
         self.hash_algo = hash_algo
+        self._state_dict_cache = {}

-    def hash(self):
+    def hash(self) -> str:
         return ModelHash(algorithm=self.hash_algo).hash(self.path)

-    def size(self):
-        if self.format_type == ModelFormat.Checkpoint:
+    def size(self) -> int:
+        if self.layout == FSLayout.FILE:
             return self.path.stat().st_size
         return sum(file.stat().st_size for file in self.path.rglob("*"))

-    def component_paths(self):
-        if self.format_type == ModelFormat.Checkpoint:
+    def component_paths(self) -> set[Path]:
+        if self.layout == FSLayout.FILE:
             return {self.path}
         extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
         return {f for f in self.path.rglob("*") if f.suffix in extensions}

-    def repo_variant(self):
-        if self.format_type == ModelFormat.Checkpoint:
+    def repo_variant(self) -> Optional[ModelRepoVariant]:
+        if self.layout == FSLayout.FILE:
             return None

         weight_files = list(self.path.glob("**/*.safetensors"))
@@ -140,14 +147,30 @@ def repo_variant(self):
             return ModelRepoVariant.ONNX
         return ModelRepoVariant.Default

-    @staticmethod
-    def load_state_dict(path: Path):
+    def load_state_dict(self, path: Optional[Path] = None) -> Dict[str | int, Any]:
+        if path in self._state_dict_cache:
+            return self._state_dict_cache[path]
+
+        if not path:
+            components = list(self.component_paths())
+            match components:
+                case []:
+                    raise ValueError("No weight files found for this model")
+                case [p]:
+                    path = p
+                case ps if len(ps) >= 2:
+                    raise ValueError(
+                        f"Multiple weight files found for this model: {ps}. "
+                        f"Please specify the intended file using the 'path' argument"
+                    )
+
         with SilenceWarnings():
             if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
                 scan_result = scan_file_path(path)
                 if scan_result.infected_files != 0 or scan_result.scan_err:
                     raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
                 checkpoint = torch.load(path, map_location="cpu")
+                assert isinstance(checkpoint, dict)
             elif path.suffix.endswith(".gguf"):
                 checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
             elif path.suffix.endswith(".safetensors"):
@@ -156,6 +179,7 @@ def load_state_dict(path: Path):
                 raise ValueError(f"Unrecognized model extension: {path.suffix}")

         state_dict = checkpoint.get("state_dict", checkpoint)
+        self._state_dict_cache[path] = state_dict
         return state_dict


@@ -238,11 +262,13 @@ def classify(model_path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single",

         for config_cls in sorted_by_match_speed:
             try:
-                return config_cls.from_model_on_disk(mod, **overrides)
-            except InvalidModelConfigException:
-                logger.debug(f"ModelConfig '{config_cls.__name__}' failed to parse '{mod.path}', trying next config")
+                if not config_cls.matches(mod):
+                    continue
             except Exception as e:
-                logger.error(f"Unexpected exception while parsing '{config_cls.__name__}': {e}, trying next config")
+                logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}")
+                continue
+            else:
+                return config_cls.from_model_on_disk(mod, **overrides)

         raise InvalidModelConfigException("No valid config found")

@@ -285,9 +311,6 @@ def cast_overrides(overrides: dict[str, Any]):
     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
         """Creates an instance of this config or raises InvalidModelConfigException."""
-        if not cls.matches(mod):
-            raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
-
         fields = cls.parse(mod)
         cls.cast_overrides(overrides)
         fields.update(overrides)
@@ -563,7 +586,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):

     @classmethod
     def matches(cls, mod: ModelOnDisk) -> bool:
-        if mod.format_type == ModelFormat.Checkpoint:
+        if mod.layout == FSLayout.FILE:
             return False

         config_path = mod.path / "config.json"
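A note on the `classify` change above: because `from_model_on_disk` now sits in the `else` clause of the `try`, only exceptions raised while matching are logged and skipped; an error while actually building a matched config still propagates, which is the "fail early" behaviour from the summary. A tiny standalone sketch (hypothetical names, not InvokeAI code) of that control flow:

```python
# `else` runs only when the `try` body raised nothing, so exceptions raised
# there are NOT caught by the handler above it.
def probe(matches: bool) -> str:
    try:
        ok = matches            # stands in for config_cls.matches(mod)
    except Exception:
        return "match error, skipped"
    else:
        if ok:
            raise RuntimeError("parse failed")  # stands in for from_model_on_disk
    return "no match, try next config"

print(probe(False))             # "no match, try next config"
try:
    probe(True)
except RuntimeError as e:
    print(f"propagated: {e}")   # parse errors are not swallowed
```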

Diff for: scripts/strip_models.py

+1 -1

@@ -71,7 +71,7 @@ def create_stripped_model(original_model_path: Path, stripped_model_path: Path)
     print(f"Created clone of {original.name} at {stripped.path}")

     for component_path in stripped.component_paths():
-        original_state_dict = ModelOnDisk.load_state_dict(component_path)
+        original_state_dict = stripped.load_state_dict(component_path)
         stripped_state_dict = strip(original_state_dict) # type: ignore
         with open(component_path, "w") as f:
             json.dump(stripped_state_dict, f, indent=4)
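The call-site update above follows from the config.py diff: `load_state_dict` is no longer a `@staticmethod`, so it is invoked on the `stripped` instance, which also lets repeated loads of the same component file hit that instance's state-dict cache.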
