Skip to content

Commit bf45a3c

Browse files
committed
Formatting
1 parent c365bbd commit bf45a3c

File tree

3 files changed

+13
-34
lines changed

3 files changed

+13
-34
lines changed

Diff for: invokeai/app/services/model_install/model_install_default.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -647,14 +647,12 @@ def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
647647
config = config or ModelRecordChanges()
648648
hash_algo = self._app_config.hashing_algorithm
649649
fields = config.model_dump()
650-
overrides = { "hash_algo": hash_algo, **fields}
650+
overrides = {"hash_algo": hash_algo, **fields}
651651

652652
try:
653653
return ModelConfigBase.classify(model_path, **overrides)
654654
except InvalidModelConfigException:
655-
return ModelProbe.probe(
656-
model_path=model_path, fields=fields, hash_algo=hash_algo
657-
) # type: ignore
655+
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
658656

659657
def _register(
660658
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None

Diff for: invokeai/backend/model_manager/config.py

+3-27
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ class ControlAdapterDefaultSettings(BaseModel):
202202

203203
class ModelOnDisk:
204204
"""A utility class representing a model stored on disk."""
205+
205206
def __init__(self, path: Path):
206207
self.path = path
207208
self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
@@ -279,7 +280,7 @@ def __init_subclass__(cls, **kwargs):
279280
@staticmethod
280281
def all_config_classes():
281282
subclasses = ModelConfigBase._USING_LEGACY_PROBE | ModelConfigBase._USING_CLASSIFY_API
282-
concrete = { cls for cls in subclasses if not isabstract(cls) }
283+
concrete = {cls for cls in subclasses if not isabstract(cls)}
283284
return concrete
284285

285286
@staticmethod
@@ -332,7 +333,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
332333

333334
fields = cls.parse(mod)
334335

335-
fields["path"] = mod.path.as_posix()
336+
fields["path"] = mod.path.as_posix()
336337
fields["source"] = fields.get("source") or fields["path"]
337338
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
338339
fields["name"] = mod.name
@@ -388,17 +389,14 @@ class T5EncoderConfigBase(ABC, BaseModel):
388389
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
389390

390391

391-
392392
class T5EncoderConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase):
393393
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
394394

395395

396-
397396
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase):
398397
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
399398

400399

401-
402400
class LoRALyCORISConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
403401
"""Model config for LoRA/Lycoris models."""
404402

@@ -411,7 +409,6 @@ class ControlAdapterConfigBase(ABC, BaseModel):
411409
)
412410

413411

414-
415412
class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
416413
"""Model config for Control LoRA models."""
417414

@@ -420,7 +417,6 @@ class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, Model
420417
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
421418

422419

423-
424420
class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
425421
"""Model config for Control LoRA models."""
426422

@@ -429,52 +425,45 @@ class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, Mod
429425
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
430426

431427

432-
433428
class LoRADiffusersConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
434429
"""Model config for LoRA/Diffusers models."""
435430

436431
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
437432

438433

439-
440434
class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase):
441435
"""Model config for standalone VAE models."""
442436

443437
type: Literal[ModelType.VAE] = ModelType.VAE
444438

445439

446-
447440
class VAEDiffusersConfig(LegacyProbeMixin, ModelConfigBase):
448441
"""Model config for standalone VAE models (diffusers version)."""
449442

450443
type: Literal[ModelType.VAE] = ModelType.VAE
451444
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
452445

453446

454-
455447
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
456448
"""Model config for ControlNet models (diffusers version)."""
457449

458450
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
459451
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
460452

461453

462-
463454
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
464455
"""Model config for ControlNet models (diffusers version)."""
465456

466457
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
467458

468459

469-
470460
class TextualInversionFileConfig(LegacyProbeMixin, ModelConfigBase):
471461
"""Model config for textual inversion embeddings."""
472462

473463
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
474464
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
475465

476466

477-
478467
class TextualInversionFolderConfig(LegacyProbeMixin, ModelConfigBase):
479468
"""Model config for textual inversion embeddings."""
480469

@@ -491,15 +480,13 @@ class MainConfigBase(ABC, BaseModel):
491480
variant: AnyVariant = ModelVariantType.Normal
492481

493482

494-
495483
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
496484
"""Model config for main checkpoint models."""
497485

498486
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
499487
upcast_attention: bool = False
500488

501489

502-
503490
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
504491
"""Model config for main checkpoint models."""
505492

@@ -508,7 +495,6 @@ class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, L
508495
upcast_attention: bool = False
509496

510497

511-
512498
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
513499
"""Model config for main checkpoint models."""
514500

@@ -517,7 +503,6 @@ class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbe
517503
upcast_attention: bool = False
518504

519505

520-
521506
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
522507
"""Model config for main diffusers models."""
523508

@@ -528,7 +513,6 @@ class IPAdapterConfigBase(ABC, BaseModel):
528513
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
529514

530515

531-
532516
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
533517
"""Model config for IP Adapter diffusers format models."""
534518

@@ -538,7 +522,6 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfig
538522
format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI
539523

540524

541-
542525
class IPAdapterCheckpointConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
543526
"""Model config for IP Adapter checkpoint format models."""
544527

@@ -553,7 +536,6 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
553536
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
554537

555538

556-
557539
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
558540
"""Model config for CLIP-G Embeddings."""
559541

@@ -564,7 +546,6 @@ def get_tag(cls) -> Tag:
564546
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")
565547

566548

567-
568549
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
569550
"""Model config for CLIP-L Embeddings."""
570551

@@ -575,23 +556,20 @@ def get_tag(cls) -> Tag:
575556
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")
576557

577558

578-
579559
class CLIPVisionDiffusersConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
580560
"""Model config for CLIPVision."""
581561

582562
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
583563
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
584564

585565

586-
587566
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
588567
"""Model config for T2I."""
589568

590569
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
591570
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
592571

593572

594-
595573
class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase):
596574
"""Model config for Spandrel Image to Image models."""
597575

@@ -601,15 +579,13 @@ class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase):
601579
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
602580

603581

604-
605582
class SigLIPConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
606583
"""Model config for SigLIP."""
607584

608585
type: Literal[ModelType.SigLIP] = ModelType.SigLIP
609586
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
610587

611588

612-
613589
class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase):
614590
"""Model config for FLUX Tools Redux model."""
615591

Diff for: tests/test_model_probe.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,16 @@ def test_serialisation_roundtrip():
177177
for config_cls in ModelConfigBase.all_config_classes() - excluded:
178178
trials_per_class = 50
179179

180-
factory_args = { "__use_defaults__": True, "__random_seed__": 1234, "__check_model__": True, }
180+
factory_args = {
181+
"__use_defaults__": True,
182+
"__random_seed__": 1234,
183+
"__check_model__": True,
184+
}
181185
factory = ModelFactory.create_factory(config_cls, **factory_args)
182186

183-
configs_with_random_data = [factory.build() for _ in range(trials_per_class)] #mocker.mock(config_cls, trials_per_class)
187+
configs_with_random_data = [
188+
factory.build() for _ in range(trials_per_class)
189+
] # mocker.mock(config_cls, trials_per_class)
184190

185191
for config in configs_with_random_data:
186192
as_json = config.model_dump_json()
@@ -204,4 +210,3 @@ def test_inheritance_order():
204210
excluded = {abc.ABC, pydantic.BaseModel, object}
205211
inheritance_list = [cls for cls in config_cls.mro() if cls not in excluded]
206212
assert inheritance_list[-1] is ModelConfigBase
207-

0 commit comments

Comments
 (0)