Commit 9432336

Add simplified model manager install API to InvocationContext (#6132)
## Summary

This adds three model manager-related methods to the InvocationContext uniform API. They are accessible via `context.models.*`:

1. **`load_local_model(model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None) -> LoadedModelWithoutConfig`**

*Load the model located at the indicated path.*

This will load a local model (.safetensors, .ckpt or diffusers directory) into the model manager RAM cache and return its `LoadedModelWithoutConfig`. If the optional loader argument is provided, the loader will be invoked to load the model into memory. Otherwise the method will call `safetensors.torch.load_file()`, `torch.load()` (with a pickle scan), or `from_pretrained()` as appropriate to the path type.

Be aware that the `LoadedModelWithoutConfig` object differs from `LoadedModel` by having no `config` attribute.

Here is an example of usage:

```
def invoke(self, context: InvocationContext) -> ImageOutput:
    model_path = Path('/opt/models/RealESRGAN_x4plus.pth')
    loadnet = context.models.load_local_model(model_path)
    with loadnet as loadnet_model:
        upscaler = RealESRGAN(loadnet=loadnet_model,...)
```

---

2. **`load_remote_model(source: str | AnyHttpUrl, loader: Optional[Callable[[Path], AnyModel]] = None) -> LoadedModelWithoutConfig`**

*Load the model located at the indicated URL or repo_id.*

This is similar to `load_local_model()` but it accepts either a HuggingFace repo_id (as a string), or a URL. The model's file(s) will be downloaded to `models/.download_cache` and then loaded, returning a `LoadedModelWithoutConfig`:

```
def invoke(self, context: InvocationContext) -> ImageOutput:
    model_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
    loadnet = context.models.load_remote_model(model_url)
    with loadnet as loadnet_model:
        upscaler = RealESRGAN(loadnet=loadnet_model,...)
```

---

3. **`download_and_cache_model(source: str | AnyHttpUrl, access_token: Optional[str] = None, timeout: Optional[int] = 0) -> Path`**

*Download the model file located at source to the models cache and return its Path.*

This will check `models/.download_cache` for the desired model file and download it from the indicated source if not already present. The local Path to the downloaded file is then returned.

---

## Other Changes

This PR performs a migration, in which it renames `models/.cache` to `models/.convert_cache`, and migrates previously-downloaded ESRGAN, openpose, DepthAnything and Lama inpaint models from the `models/core` directory into `models/.download_cache`.

There are a number of legacy model files in `models/core`, such as GFPGAN, which are no longer used. This PR deletes them and tidies up the `models/core` directory.

## Related Issues / Discussions

I have systematically replaced all the calls to `download_with_progress_bar()`. This function is no longer used elsewhere and has been removed.

## QA Instructions

I have added unit tests for the three new calls. You may test that the `load_and_cache_model()` call is working by running the upscaler within the web app. On first try, you will see the model file being downloaded into the models `.cache` directory. On subsequent tries, the model will either load from RAM (if it hasn't been displaced) or will be loaded from the filesystem.

## Merge Plan

Squash merge when approved.

## Checklist

- [X] _The PR has a short but descriptive title, suitable for a changelog_
- [X] _Tests added / updated (if applicable)_
- [X] _Documentation added / updated (if applicable)_
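The check-the-cache-then-download behavior described for `download_and_cache_model()` can be illustrated with a minimal, standalone sketch. This is not InvokeAI's implementation: the `cached_download` function, the hash-based directory layout, and the `fetch` callback are all assumptions made for illustration, and the example URL is a placeholder.

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Callable, List


def cached_download(source: str, cache_dir: Path, fetch: Callable[[str, Path], None]) -> Path:
    """Return the cached file for `source`, calling `fetch` only on a cache miss."""
    # Assumed layout: one subdirectory per source, named by a hash of the source string.
    key = hashlib.sha256(source.encode("utf-8")).hexdigest()[:16]
    filename = source.rstrip("/").rsplit("/", 1)[-1] or "model"
    target = cache_dir / key / filename
    if not target.exists():
        target.parent.mkdir(parents=True, exist_ok=True)
        fetch(source, target)  # only reached when the file is not yet cached
    return target


# Demo with a fake fetcher so that no network access is needed.
calls: List[str] = []


def fake_fetch(url: str, dest: Path) -> None:
    calls.append(url)
    dest.write_bytes(b"weights")


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / ".download_cache"
    url = "https://example.com/models/RealESRGAN_x4plus.pth"
    first = cached_download(url, cache, fake_fetch)
    second = cached_download(url, cache, fake_fetch)  # served from the cache
    same_path = first == second
```

The second call returns the same path without invoking the fetcher again, which is the property the QA instructions rely on when the upscaler is run a second time.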
2 parents 0dbec3a + 7d19af2 commit 9432336

42 files changed: +1517 -652 lines changed

docs/contributing/DOWNLOAD_QUEUE.md

+60-3
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects
128128
specify the source and destination of the download, and keep track of
129129
the progress of the download.
130130

131-
The only job type currently implemented is `DownloadJob`, a pydantic object with the
131+
Two job types are defined: `DownloadJob` and
132+
`MultiFileDownloadJob`. The former is a pydantic object with the
132133
following fields:
133134

134135
| **Field** | **Type** | **Default** | **Description** |
@@ -138,7 +139,7 @@ following fields:
138139
| `dest` | Path | | Where to download to |
139140
| `access_token` | str | | [optional] string containing authentication token for access |
140141
| `on_start` | Callable | | [optional] callback when the download starts |
141-
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
142+
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
142143
| `on_complete` | Callable | | [optional] callback called after successful download completion |
143144
| `on_error` | Callable | | [optional] callback called after an error occurs |
144145
| `id` | int | auto assigned | Job ID, an integer >= 0 |
@@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
190191
`error_type` field of "DownloadJobCancelledException". In addition,
191192
the job's `cancelled` property will be set to True.
192193

194+
The `MultiFileDownloadJob` is used for diffusers model downloads,
195+
which contain multiple files and directories under a common root:
196+
197+
| **Field** | **Type** | **Default** | **Description** |
198+
|----------------|-----------------|---------------|-----------------|
199+
| _Fields passed in at job creation time_ |
200+
| `download_parts` | Set[DownloadJob]| | Component download jobs |
201+
| `dest` | Path | | Where to download to |
202+
| `on_start` | Callable | | [optional] callback when the download starts |
203+
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
204+
| `on_complete` | Callable | | [optional] callback called after successful download completion |
205+
| `on_error` | Callable | | [optional] callback called after an error occurs |
206+
| `id` | int | auto assigned | Job ID, an integer >= 0 |
207+
| _Fields updated over the course of the download task_ |
208+
| `status` | DownloadJobStatus| | Status code |
209+
| `download_path` | Path | | Path to the root of the downloaded files |
210+
| `bytes` | int | 0 | Bytes downloaded so far |
211+
| `total_bytes` | int | 0 | Total size of the file at the remote site |
212+
| `error_type` | str | | String version of the exception that caused an error during download |
213+
| `error` | str | | String version of the traceback associated with an error |
214+
| `cancelled` | bool | False | Set to true if the job was cancelled by the caller|
215+
216+
Note that the MultiFileDownloadJob does not support the `priority`,
217+
`job_started`, `job_ended` or `content_type` attributes. You can get
218+
these from the individual download jobs in `download_parts`.
219+
220+
193221
### Callbacks
194222

195223
Download jobs can be associated with a series of callbacks, each with
@@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its with
251279
running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
252280
with `join()`.
253281

254-
#### job = queue.download(source, dest, priority, access_token)
282+
#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
255283

256284
Create a new download job and put it on the queue, returning the
257285
DownloadJob object.
258286

287+
#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
288+
289+
This is similar to download(), but instead of taking a single source,
290+
it accepts a `parts` argument consisting of a list of
291+
`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
292+
where the URL is the location of the remote file, and the Path is the
293+
destination.
294+
295+
`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and
296+
consists of a url/path pair. Note that the path *must* be relative.
297+
298+
The method returns a `MultiFileDownloadJob`.
299+
300+
301+
```
302+
from invokeai.backend.model_manager.metadata import RemoteModelFile
303+
remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors',
304+
path='my_model/textencoder/pytorch_model.safetensors'
305+
)
306+
remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
307+
path='my_model/vae/diffusers_model.safetensors'
308+
)
309+
job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
310+
dest='/tmp/downloads',
311+
on_progress=TqdmProgress().update)
312+
queue.wait_for_job(job)
313+
print(f"The files were downloaded to {job.download_path}")
314+
```
315+
259316
#### jobs = queue.list_jobs()
260317

261318
Return a list of all active and inactive `DownloadJob`s.
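The rule that each part's path *must* be relative can be illustrated with a small standalone sketch. It does not use the real `RemoteModelFile` class or the download queue: the `resolve_part_destinations` helper and its dict-based input are hypothetical, and the sketch only shows how relative part paths combine with `dest` and why an absolute path is rejected.

```python
from pathlib import PurePosixPath
from typing import Dict


def resolve_part_destinations(parts: Dict[str, str], dest: str) -> Dict[str, PurePosixPath]:
    """Map each part URL to its final location under `dest` (hypothetical helper)."""
    root = PurePosixPath(dest)
    resolved: Dict[str, PurePosixPath] = {}
    for url, rel in parts.items():
        rel_path = PurePosixPath(rel)
        if rel_path.is_absolute():
            # An absolute path would escape the common download root.
            raise ValueError(f"part path must be relative: {rel}")
        resolved[url] = root / rel_path
    return resolved


resolved = resolve_part_destinations(
    {
        "http://www.foo.bar/my/pytorch_model.safetensors": "my_model/textencoder/pytorch_model.safetensors",
        "http://www.bar.baz/vae.ckpt": "my_model/vae/diffusers_model.safetensors",
    },
    dest="/tmp/downloads",
)

# An absolute part path is refused.
try:
    resolve_part_destinations({"http://www.bar.baz/vae.ckpt": "/etc/vae.ckpt"}, dest="/tmp/downloads")
    rejected = False
except ValueError:
    rejected = True
```

This sketch does not guard against `..` components in a part path; a real implementation would likely need to.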

docs/contributing/MODEL_MANAGER.md

+63-8
@@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the
397397
following initialization pattern:
398398

399399
```
400-
from invokeai.app.services.config import InvokeAIAppConfig
400+
from invokeai.app.services.config import get_config
401401
from invokeai.app.services.model_records import ModelRecordServiceSQL
402402
from invokeai.app.services.model_install import ModelInstallService
403403
from invokeai.app.services.download import DownloadQueueService
404-
from invokeai.app.services.shared.sqlite import SqliteDatabase
404+
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
405405
from invokeai.backend.util.logging import InvokeAILogger
406406
407-
config = InvokeAIAppConfig.get_config()
408-
config.parse_args()
407+
config = get_config()
409408
410409
logger = InvokeAILogger.get_logger(config=config)
411-
db = SqliteDatabase(config, logger)
410+
db = SqliteDatabase(config.db_path, logger)
412411
record_store = ModelRecordServiceSQL(db)
413412
queue = DownloadQueueService()
414413
queue.start()
415414
416-
installer = ModelInstallService(app_config=config,
415+
installer = ModelInstallService(app_config=config,
417416
record_store=record_store,
418-
download_queue=queue
419-
)
417+
download_queue=queue
418+
)
420419
installer.start()
421420
```
422421

@@ -1602,3 +1601,59 @@ This method takes a model key, looks it up using the
16021601
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
16031602
model configuration to `load_model_by_config()`. It may raise a
16041603
`NotImplementedException`.
1604+
1605+
## Invocation Context Model Manager API
1606+
1607+
Within invocations, the following methods are available from the
1608+
`InvocationContext` object:
1609+
1610+
### context.models.download_and_cache_model(source) -> Path
1611+
1612+
This method accepts a `source` of a remote model, downloads and caches
1613+
it locally, and then returns a Path to the local model. The source can
1614+
be a direct download URL or a HuggingFace repo_id.
1615+
1616+
In the case of HuggingFace repo_id, the following variants are
1617+
recognized:
1618+
1619+
* stabilityai/stable-diffusion-v4 -- default model
1620+
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
1621+
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
1622+
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
1623+
1624+
You can also point at an arbitrary individual file within a repo_id
1625+
directory using this syntax:
1626+
1627+
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
1628+
1629+
### context.models.load_local_model(model_path, [loader]) -> LoadedModel
1630+
1631+
This method loads a local model from the indicated path, returning a
1632+
`LoadedModel`. The optional loader is a Callable that accepts a Path
1633+
to the object, and returns an `AnyModel` object. If no loader is
1634+
provided, then the method will use `torch.load()` for a .ckpt or .bin
1635+
checkpoint file, `safetensors.torch.load_file()` for a safetensors
1636+
checkpoint file, or `cls.from_pretrained()` for a directory that looks
1637+
like a diffusers directory.
1638+
1639+
### context.models.load_remote_model(source, [loader]) -> LoadedModel
1640+
1641+
This method accepts a `source` of a remote model, downloads and caches
1642+
it locally, loads it, and returns a `LoadedModel`. The source can be a
1643+
direct download URL or a HuggingFace repo_id.
1644+
1645+
In the case of HuggingFace repo_id, the following variants are
1646+
recognized:
1647+
1648+
* stabilityai/stable-diffusion-v4 -- default model
1649+
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
1650+
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
1651+
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
1652+
1653+
You can also point at an arbitrary individual file within a repo_id
1654+
directory using this syntax:
1655+
1656+
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
1657+
1658+
1659+
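The repo_id variants listed in the documentation above follow a `repo_id:variant:subfolder` shape, with an empty variant in the `::` form. A minimal sketch of splitting such a string is shown below. The `ParsedRepoSource` dataclass and `parse_repo_source` function are hypothetical names, not InvokeAI APIs, and the sketch handles repo_id strings only (a URL source contains colons of its own and would need separate handling).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ParsedRepoSource:
    """Hypothetical container for the pieces of a repo_id source string."""

    repo_id: str
    variant: Optional[str] = None
    subfolder: Optional[str] = None


def parse_repo_source(source: str) -> ParsedRepoSource:
    """Split 'repo_id[:variant[:subfolder]]' into its components."""
    parts = source.split(":", 2)
    repo_id = parts[0]
    # Empty fields (as in 'repo::/path') are treated as "not specified".
    variant = parts[1] if len(parts) > 1 and parts[1] else None
    subfolder = parts[2] if len(parts) > 2 and parts[2] else None
    return ParsedRepoSource(repo_id, variant, subfolder)


default = parse_repo_source("stabilityai/stable-diffusion-v4")
fp16_vae = parse_repo_source("stabilityai/stable-diffusion-v4:fp16:vae")
single_file = parse_repo_source("stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors")
```

Here `single_file.variant` is `None` and the third field carries the file path, mirroring the `::` syntax described above.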

invokeai/app/api/dependencies.py

+1-1
@@ -93,7 +93,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger
9393
conditioning = ObjectSerializerForwardCache(
9494
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
9595
)
96-
download_queue_service = DownloadQueueService(event_bus=events)
96+
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
9797
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
9898
model_manager = ModelManagerService.build_model_manager(
9999
app_config=configuration,

invokeai/app/invocations/controlnet_image_processors.py

+34-21
@@ -2,6 +2,7 @@
22
# initial implementation by Gregg Helt, 2023
33
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
44
from builtins import bool, float
5+
from pathlib import Path
56
from typing import Dict, List, Literal, Union
67

78
import cv2
@@ -36,12 +37,13 @@
3637
from invokeai.app.services.shared.invocation_context import InvocationContext
3738
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
3839
from invokeai.backend.image_util.canny import get_canny_edges
39-
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
40-
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
40+
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
41+
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
4142
from invokeai.backend.image_util.hed import HEDProcessor
4243
from invokeai.backend.image_util.lineart import LineartProcessor
4344
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
4445
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
46+
from invokeai.backend.util.devices import TorchDevice
4547

4648
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
4749

@@ -139,6 +141,7 @@ def load_image(self, context: InvocationContext) -> Image.Image:
139141
return context.images.get_pil(self.image.image_name, "RGB")
140142

141143
def invoke(self, context: InvocationContext) -> ImageOutput:
144+
self._context = context
142145
raw_image = self.load_image(context)
143146
# image type should be PIL.PngImagePlugin.PngImageFile ?
144147
processed_image = self.run_processor(raw_image)
@@ -284,7 +287,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
284287
# depth_and_normal not supported in controlnet_aux v0.0.3
285288
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
286289

287-
def run_processor(self, image):
290+
def run_processor(self, image: Image.Image) -> Image.Image:
291+
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
288292
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
289293
processed_image = midas_processor(
290294
image,
@@ -311,7 +315,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
311315
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
312316
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
313317

314-
def run_processor(self, image):
318+
def run_processor(self, image: Image.Image) -> Image.Image:
315319
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
316320
processed_image = normalbae_processor(
317321
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
@@ -330,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
330334
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
331335
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
332336

333-
def run_processor(self, image):
337+
def run_processor(self, image: Image.Image) -> Image.Image:
334338
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
335339
processed_image = mlsd_processor(
336340
image,
@@ -353,7 +357,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
353357
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
354358
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
355359

356-
def run_processor(self, image):
360+
def run_processor(self, image: Image.Image) -> Image.Image:
357361
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
358362
processed_image = pidi_processor(
359363
image,
@@ -381,7 +385,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
381385
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
382386
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
383387

384-
def run_processor(self, image):
388+
def run_processor(self, image: Image.Image) -> Image.Image:
385389
content_shuffle_processor = ContentShuffleDetector()
386390
processed_image = content_shuffle_processor(
387391
image,
@@ -405,7 +409,7 @@ def run_processor(self, image):
405409
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
406410
"""Applies Zoe depth processing to image"""
407411

408-
def run_processor(self, image):
412+
def run_processor(self, image: Image.Image) -> Image.Image:
409413
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
410414
processed_image = zoe_depth_processor(image)
411415
return processed_image
@@ -426,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
426430
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
427431
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
428432

429-
def run_processor(self, image):
433+
def run_processor(self, image: Image.Image) -> Image.Image:
430434
mediapipe_face_processor = MediapipeFaceDetector()
431435
processed_image = mediapipe_face_processor(
432436
image,
@@ -454,7 +458,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
454458
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
455459
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
456460

457-
def run_processor(self, image):
461+
def run_processor(self, image: Image.Image) -> Image.Image:
458462
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
459463
processed_image = leres_processor(
460464
image,
@@ -496,8 +500,8 @@ def tile_resample(
496500
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
497501
return np_img
498502

499-
def run_processor(self, img):
500-
np_img = np.array(img, dtype=np.uint8)
503+
def run_processor(self, image: Image.Image) -> Image.Image:
504+
np_img = np.array(image, dtype=np.uint8)
501505
processed_np_image = self.tile_resample(
502506
np_img,
503507
# res=self.tile_size,
@@ -520,7 +524,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
520524
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
521525
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
522526

523-
def run_processor(self, image):
527+
def run_processor(self, image: Image.Image) -> Image.Image:
524528
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
525529
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
526530
"ybelkada/segment-anything", subfolder="checkpoints"
@@ -566,7 +570,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
566570

567571
color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
568572

569-
def run_processor(self, image: Image.Image):
573+
def run_processor(self, image: Image.Image) -> Image.Image:
570574
np_image = np.array(image, dtype=np.uint8)
571575
height, width = np_image.shape[:2]
572576

@@ -601,12 +605,18 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
601605
)
602606
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
603607

604-
def run_processor(self, image: Image.Image):
605-
depth_anything_detector = DepthAnythingDetector()
606-
depth_anything_detector.load_model(model_size=self.model_size)
608+
def run_processor(self, image: Image.Image) -> Image.Image:
609+
def loader(model_path: Path):
610+
return DepthAnythingDetector.load_model(
611+
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
612+
)
607613

608-
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
609-
return processed_image
614+
with self._context.models.load_remote_model(
615+
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
616+
) as model:
617+
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
618+
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
619+
return processed_image
610620

611621

612622
@invocation(
@@ -624,8 +634,11 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
624634
draw_hands: bool = InputField(default=False)
625635
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
626636

627-
def run_processor(self, image: Image.Image):
628-
dw_openpose = DWOpenposeDetector()
637+
def run_processor(self, image: Image.Image) -> Image.Image:
638+
onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
639+
onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
640+
641+
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
629642
processed_image = dw_openpose(
630643
image,
631644
draw_face=self.draw_face,
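The DepthAnything change in this file passes a `loader` callback to `load_remote_model()` and uses the result as a context manager. The general pattern can be sketched in isolation as follows. The `ModelCacheSketch` and `LoadedModelSketch` classes are illustrative stand-ins, not the model manager's real classes, and the in-memory dict is an assumed simplification of the RAM cache.

```python
from pathlib import Path
from typing import Any, Callable, Dict, List


class LoadedModelSketch:
    """Stand-in for a loaded-model handle that is used in a `with` block."""

    def __init__(self, model: Any) -> None:
        self._model = model

    def __enter__(self) -> Any:
        return self._model

    def __exit__(self, *exc_info: Any) -> None:
        # Returning None lets any exception raised inside the block propagate.
        return None


class ModelCacheSketch:
    """Stand-in cache: invokes the caller's loader once per path."""

    def __init__(self) -> None:
        self._cache: Dict[str, Any] = {}

    def load(self, path: Path, loader: Callable[[Path], Any]) -> LoadedModelSketch:
        key = str(path)
        if key not in self._cache:
            self._cache[key] = loader(path)  # custom loader decides how to build the model
        return LoadedModelSketch(self._cache[key])


loader_calls: List[Path] = []


def demo_loader(model_path: Path) -> Dict[str, str]:
    loader_calls.append(model_path)
    return {"weights_from": model_path.name}


cache = ModelCacheSketch()
results = []
for _ in range(2):
    with cache.load(Path("depth_anything_vits14.pth"), demo_loader) as model:
        results.append(model["weights_from"])
```

The loader runs only on the first request; the second `with` block receives the already-loaded object, which is roughly the behavior the QA instructions describe for repeated upscaler runs.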
