Skip to content

Commit 70e40fa

Browse files
authored
added route to install huggingface models from model marketplace (#6515)
## Summary added route to install huggingface models from model marketplace <!--A description of the changes in this PR. Include the kind of change (fix, feature, docs, etc), the "why" and the "how". Screenshots or videos are useful for frontend changes.--> ## Related Issues / Discussions <!--WHEN APPLICABLE: List any related issues or discussions on github or discord. If this PR closes an issue, please use the "Closes #1234" format, so that the issue will be automatically closed when the PR merges.--> ## QA Instructions test by going to http://localhost:5173/api/v2/models/install/huggingface?source=${hfRepo} <!--WHEN APPLICABLE: Describe how we can test the changes in this PR.--> ## Merge Plan <!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like DB schemas, may need some care when merging. For example, a careful rebase by the change author, timing to not interfere with a pending release, or a message to contributors on discord after merging.--> ## Checklist - [ ] _The PR has a short but descriptive title, suitable for a changelog_ - [ ] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_
2 parents 785bb1d + e26125b commit 70e40fa

File tree

9 files changed

+582
-204
lines changed

9 files changed

+582
-204
lines changed

invokeai/app/api/routers/model_manager.py

+128-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import Any, Dict, List, Optional, Type
1010

1111
from fastapi import Body, Path, Query, Response, UploadFile
12-
from fastapi.responses import FileResponse
12+
from fastapi.responses import FileResponse, HTMLResponse
1313
from fastapi.routing import APIRouter
1414
from PIL import Image
1515
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
@@ -502,6 +502,133 @@ async def install_model(
502502
return result
503503

504504

505+
@model_manager_router.get(
    "/install/huggingface",
    operation_id="install_hugging_face_model",
    responses={
        201: {"description": "The model is being installed"},
        400: {"description": "Bad request"},
        409: {"description": "There is already a model corresponding to this path or repo_id"},
    },
    status_code=201,
    response_class=HTMLResponse,
)
async def install_hugging_face_model(
    source: str = Query(description="HuggingFace repo_id to install"),
) -> HTMLResponse:
    """Install a Hugging Face model using a string identifier.

    Fetches the repo's metadata, then starts a background install via the model
    manager's heuristic importer. Always returns a small self-contained HTML
    status page, so the route can be opened directly in a browser tab
    (e.g. from the model marketplace).
    """

    def generate_html(title: str, heading: str, repo_id: str, is_error: bool, message: str | None = "") -> str:
        # Render the status page. `message` is optional extra detail shown
        # below the heading; an empty/None message is omitted entirely.
        if message:
            message = f"<p>{message}</p>"
        title_class = "error" if is_error else "success"
        return f"""
        <html>

        <head>
            <title>{title}</title>
            <style>
                body {{
                    text-align: center;
                    background-color: hsl(220 12% 10% / 1);
                    font-family: Helvetica, sans-serif;
                    color: hsl(220 12% 86% / 1);
                }}

                .repo-id {{
                    color: hsl(220 12% 68% / 1);
                }}

                .error {{
                    color: hsl(0 42% 68% / 1)
                }}

                .message-box {{
                    display: inline-block;
                    border-radius: 5px;
                    background-color: hsl(220 12% 20% / 1);
                    padding-inline-end: 30px;
                    padding: 20px;
                    padding-inline-start: 30px;
                    padding-inline-end: 30px;
                }}

                .container {{
                    display: flex;
                    width: 100%;
                    height: 100%;
                    align-items: center;
                    justify-content: center;
                }}

                a {{
                    color: inherit
                }}

                a:visited {{
                    color: inherit
                }}

                a:active {{
                    color: inherit
                }}
            </style>
        </head>

        <body style="background-color: hsl(220 12% 10% / 1);">
            <div class="container">
                <div class="message-box">
                    <h2 class="{title_class}">{heading}</h2>
                    {message}
                    <p class="repo-id">Repo ID: {repo_id}</p>
                </div>
            </div>
        </body>

        </html>
        """

    try:
        metadata = HuggingFaceMetadataFetch().from_id(source)
        assert isinstance(metadata, ModelMetadataWithFiles)
    except UnknownMetadataException:
        # Repo id did not resolve to a HuggingFace repository.
        title = "Unable to Install Model"
        heading = "No HuggingFace repository found with that repo ID."
        message = "Ensure the repo ID is correct and try again."
        return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=400)

    logger = ApiDependencies.invoker.services.logger

    try:
        installer = ApiDependencies.invoker.services.model_manager.install
        if metadata.is_diffusers:
            # Diffusers repos are installed from the repo id itself.
            installer.heuristic_import(
                source=source,
                inplace=False,
            )
        elif metadata.ckpt_urls is not None and len(metadata.ckpt_urls) == 1:
            # A repo with exactly one checkpoint is installed from that file's URL.
            installer.heuristic_import(
                source=str(metadata.ckpt_urls[0]),
                inplace=False,
            )
        else:
            # Ambiguous repo (multiple checkpoints) - the user must choose one
            # in the Model Manager UI. Fix: report this as 400 (the route's
            # documented "Bad request" response) instead of the previous 200,
            # which labelled an error page as success.
            title = "Unable to Install Model"
            heading = "This HuggingFace repo has multiple models."
            message = "Please use the Model Manager to install this model."
            return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=400)

        title = "Model Install Started"
        heading = "Your HuggingFace model is installing now."
        message = "You can close this tab and check the Model Manager for installation progress."
        return HTMLResponse(content=generate_html(title, heading, source, False, message), status_code=201)
    except Exception as e:
        # Catch-all boundary: log and show a generic failure page rather than
        # leaking a traceback to the browser.
        logger.error(str(e))
        title = "Unable to Install Model"
        # Fix: "an problem" -> "a problem" (typo in user-facing message).
        heading = "There was a problem installing this model."
        message = 'Please use the Model Manager directly to install this model. If the issue persists, ask for help on <a href="https://discord.gg/ZmtBAhwWhy">discord</a>.'
        return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=500)
630+
631+
505632
@model_manager_router.get(
506633
"/install",
507634
operation_id="list_model_installs",

invokeai/app/services/events/events_base.py

+5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
ModelInstallCompleteEvent,
2323
ModelInstallDownloadProgressEvent,
2424
ModelInstallDownloadsCompleteEvent,
25+
ModelInstallDownloadStartedEvent,
2526
ModelInstallErrorEvent,
2627
ModelInstallStartedEvent,
2728
ModelLoadCompleteEvent,
@@ -144,6 +145,10 @@ def emit_model_load_complete(
144145

145146
# region Model install
146147

148+
def emit_model_install_download_started(self, job: "ModelInstallJob") -> None:
    """Emitted once when an install job's download begins (remote models only)."""
    self.dispatch(ModelInstallDownloadStartedEvent.build(job))
151+
147152
def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
148153
"""Emitted at intervals while the install job is in progress (remote models only)."""
149154
self.dispatch(ModelInstallDownloadProgressEvent.build(job))

invokeai/app/services/events/events_common.py

+36
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,42 @@ def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = N
417417
return cls(config=config, submodel_type=submodel_type)
418418

419419

420+
@payload_schema.register
class ModelInstallDownloadStartedEvent(ModelEventBase):
    """Event model for model_install_download_started"""

    __event_name__ = "model_install_download_started"

    id: int = Field(description="The ID of the install job")
    source: str = Field(description="Source of the model; local path, repo_id or url")
    local_path: str = Field(description="Where model is downloading to")
    bytes: int = Field(description="Number of bytes downloaded so far")
    total_bytes: int = Field(description="Total size of download, including all files")
    parts: list[dict[str, int | str]] = Field(
        description="Progress of downloading URLs that comprise the model, if any"
    )

    @classmethod
    def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadStartedEvent":
        # Snapshot the per-file progress of every download part of the job.
        part_summaries: list[dict[str, str | int]] = []
        for part in job.download_parts:
            part_summaries.append(
                {
                    "url": str(part.source),
                    "local_path": str(part.download_path),
                    "bytes": part.bytes,
                    "total_bytes": part.total_bytes,
                }
            )
        return cls(
            id=job.id,
            source=str(job.source),
            local_path=job.local_path.as_posix(),
            parts=part_summaries,
            bytes=job.bytes,
            total_bytes=job.total_bytes,
        )
455+
420456
@payload_schema.register
421457
class ModelInstallDownloadProgressEvent(ModelEventBase):
422458
"""Event model for model_install_download_progress"""

invokeai/app/services/model_install/model_install_default.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None
822822
install_job.download_parts = download_job.download_parts
823823
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
824824
install_job.total_bytes = download_job.total_bytes
825-
self._signal_job_downloading(install_job)
825+
self._signal_job_download_started(install_job)
826826

827827
def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
828828
with self._lock:
@@ -874,6 +874,13 @@ def _signal_job_running(self, job: ModelInstallJob) -> None:
874874
if self._event_bus:
875875
self._event_bus.emit_model_install_started(job)
876876

877+
def _signal_job_download_started(self, job: ModelInstallJob) -> None:
    """Publish a model_install_download_started event for `job` when an event bus is configured."""
    if not self._event_bus:
        return
    # These fields are populated by the download-started callback before this
    # signal fires; presumably guaranteed by the caller - the asserts document that.
    assert job._multifile_job is not None
    assert job.bytes is not None
    assert job.total_bytes is not None
    self._event_bus.emit_model_install_download_started(job)
883+
877884
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
878885
if self._event_bus:
879886
assert job._multifile_job is not None

0 commit comments

Comments
 (0)