|
9 | 9 | from typing import Any, Dict, List, Optional, Type
|
10 | 10 |
|
11 | 11 | from fastapi import Body, Path, Query, Response, UploadFile
|
12 |
| -from fastapi.responses import FileResponse |
| 12 | +from fastapi.responses import FileResponse, HTMLResponse |
13 | 13 | from fastapi.routing import APIRouter
|
14 | 14 | from PIL import Image
|
15 | 15 | from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
|
@@ -502,6 +502,133 @@ async def install_model(
|
502 | 502 | return result
|
503 | 503 |
|
504 | 504 |
|
| 505 | +@model_manager_router.get( |
| 506 | + "/install/huggingface", |
| 507 | + operation_id="install_hugging_face_model", |
| 508 | + responses={ |
| 509 | + 201: {"description": "The model is being installed"}, |
| 510 | + 400: {"description": "Bad request"}, |
| 511 | + 409: {"description": "There is already a model corresponding to this path or repo_id"}, |
| 512 | + }, |
| 513 | + status_code=201, |
| 514 | + response_class=HTMLResponse, |
| 515 | +) |
| 516 | +async def install_hugging_face_model( |
| 517 | + source: str = Query(description="HuggingFace repo_id to install"), |
| 518 | +) -> HTMLResponse: |
| 519 | + """Install a Hugging Face model using a string identifier.""" |
| 520 | + |
| 521 | + def generate_html(title: str, heading: str, repo_id: str, is_error: bool, message: str | None = "") -> str: |
| 522 | + if message: |
| 523 | + message = f"<p>{message}</p>" |
| 524 | + title_class = "error" if is_error else "success" |
| 525 | + return f""" |
| 526 | + <html> |
| 527 | +
|
| 528 | + <head> |
| 529 | + <title>{title}</title> |
| 530 | + <style> |
| 531 | + body {{ |
| 532 | + text-align: center; |
| 533 | + background-color: hsl(220 12% 10% / 1); |
| 534 | + font-family: Helvetica, sans-serif; |
| 535 | + color: hsl(220 12% 86% / 1); |
| 536 | + }} |
| 537 | +
|
| 538 | + .repo-id {{ |
| 539 | + color: hsl(220 12% 68% / 1); |
| 540 | + }} |
| 541 | +
|
| 542 | + .error {{ |
| 543 | + color: hsl(0 42% 68% / 1) |
| 544 | + }} |
| 545 | +
|
| 546 | + .message-box {{ |
| 547 | + display: inline-block; |
| 548 | + border-radius: 5px; |
| 549 | + background-color: hsl(220 12% 20% / 1); |
| 550 | + padding-inline-end: 30px; |
| 551 | + padding: 20px; |
| 552 | + padding-inline-start: 30px; |
| 553 | + padding-inline-end: 30px; |
| 554 | + }} |
| 555 | +
|
| 556 | + .container {{ |
| 557 | + display: flex; |
| 558 | + width: 100%; |
| 559 | + height: 100%; |
| 560 | + align-items: center; |
| 561 | + justify-content: center; |
| 562 | + }} |
| 563 | +
|
| 564 | + a {{ |
| 565 | + color: inherit |
| 566 | + }} |
| 567 | +
|
| 568 | + a:visited {{ |
| 569 | + color: inherit |
| 570 | + }} |
| 571 | +
|
| 572 | + a:active {{ |
| 573 | + color: inherit |
| 574 | + }} |
| 575 | + </style> |
| 576 | + </head> |
| 577 | +
|
| 578 | + <body style="background-color: hsl(220 12% 10% / 1);"> |
| 579 | + <div class="container"> |
| 580 | + <div class="message-box"> |
| 581 | + <h2 class="{title_class}">{heading}</h2> |
| 582 | + {message} |
| 583 | + <p class="repo-id">Repo ID: {repo_id}</p> |
| 584 | + </div> |
| 585 | + </div> |
| 586 | + </body> |
| 587 | +
|
| 588 | + </html> |
| 589 | + """ |
| 590 | + |
| 591 | + try: |
| 592 | + metadata = HuggingFaceMetadataFetch().from_id(source) |
| 593 | + assert isinstance(metadata, ModelMetadataWithFiles) |
| 594 | + except UnknownMetadataException: |
| 595 | + title = "Unable to Install Model" |
| 596 | + heading = "No HuggingFace repository found with that repo ID." |
| 597 | + message = "Ensure the repo ID is correct and try again." |
| 598 | + return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=400) |
| 599 | + |
| 600 | + logger = ApiDependencies.invoker.services.logger |
| 601 | + |
| 602 | + try: |
| 603 | + installer = ApiDependencies.invoker.services.model_manager.install |
| 604 | + if metadata.is_diffusers: |
| 605 | + installer.heuristic_import( |
| 606 | + source=source, |
| 607 | + inplace=False, |
| 608 | + ) |
| 609 | + elif metadata.ckpt_urls is not None and len(metadata.ckpt_urls) == 1: |
| 610 | + installer.heuristic_import( |
| 611 | + source=str(metadata.ckpt_urls[0]), |
| 612 | + inplace=False, |
| 613 | + ) |
| 614 | + else: |
| 615 | + title = "Unable to Install Model" |
| 616 | + heading = "This HuggingFace repo has multiple models." |
| 617 | + message = "Please use the Model Manager to install this model." |
| 618 | + return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=200) |
| 619 | + |
| 620 | + title = "Model Install Started" |
| 621 | + heading = "Your HuggingFace model is installing now." |
| 622 | + message = "You can close this tab and check the Model Manager for installation progress." |
| 623 | + return HTMLResponse(content=generate_html(title, heading, source, False, message), status_code=201) |
| 624 | + except Exception as e: |
| 625 | + logger.error(str(e)) |
| 626 | + title = "Unable to Install Model" |
| 627 | + heading = "There was an problem installing this model." |
| 628 | + message = 'Please use the Model Manager directly to install this model. If the issue persists, ask for help on <a href="https://discord.gg/ZmtBAhwWhy">discord</a>.' |
| 629 | + return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=500) |
| 630 | + |
| 631 | + |
505 | 632 | @model_manager_router.get(
|
506 | 633 | "/install",
|
507 | 634 | operation_id="list_model_installs",
|
|
0 commit comments