Skip to content

Commit 54d6b6f

Browse files
fix: minor fixes to types in the DA Detector
1 parent bf18459 commit 54d6b6f

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

invokeai/backend/image_util/depth_anything/__init__.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from PIL import Image
1010
from torchvision.transforms import Compose
1111

12+
from build.lib.invokeai.backend.model_management.models.base import ModelNotFoundException
1213
from invokeai.app.services.config.config_default import InvokeAIAppConfig
1314
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
1415
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
@@ -17,6 +18,8 @@
1718

1819
# Module-level InvokeAI application configuration; load_model() reads
# config.models_path from it to resolve where checkpoints are stored.
config = InvokeAIAppConfig.get_config()
1920

21+
# The size variants the detector supports; used to type load_model()'s
# model_size parameter and DepthAnythingDetector.model_size.
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
22+
2023
DEPTH_ANYTHING_MODELS = {
2124
"large": {
2225
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
@@ -53,9 +56,9 @@
5356
class DepthAnythingDetector:
5457
def __init__(self) -> None:
    """Construct an empty detector.

    No weights are attached at construction time; load_model() populates
    both attributes before __call__ can be used.
    """
    # Which size variant is currently loaded (None until load_model() runs).
    self.model_size: Union[DEPTH_ANYTHING_MODEL_SIZES, None] = None
    # The DPT_DINOv2 network itself (None until load_model() runs).
    self.model = None
5760

58-
def load_model(self, model_size=Literal["large", "base", "small"]):
61+
def load_model(self, model_size: DEPTH_ANYTHING_MODEL_SIZES = "small"):
5962
DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
6063
if not DEPTH_ANYTHING_MODEL_PATH.exists():
6164
download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)
@@ -84,16 +87,19 @@ def to(self, device):
8487
self.model.to(device)
8588
return self
8689

87-
def __call__(self, image, resolution=512):
88-
image = np.array(image, dtype=np.uint8)
89-
image = image[:, :, ::-1] / 255.0
90+
def __call__(self, image: Image.Image, resolution: int = 512):
91+
if self.model is None:
92+
raise ModelNotFoundException("Depth Anything Model not loaded")
93+
94+
np_image = np.array(image, dtype=np.uint8)
95+
np_image = np_image[:, :, ::-1] / 255.0
9096

91-
image_height, image_width = image.shape[:2]
92-
image = transform({"image": image})["image"]
93-
image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
97+
image_height, image_width = np_image.shape[:2]
98+
np_image = transform({"image": image})["image"]
99+
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
94100

95101
with torch.no_grad():
96-
depth = self.model(image)
102+
depth = self.model(tensor_image)
97103
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
98104
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
99105

0 commit comments

Comments (0)