17 | 17 |
18 | 18 | config = InvokeAIAppConfig.get_config()
19 | 19 |
   | 20 | +DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
   | 21 | +
20 | 22 | DEPTH_ANYTHING_MODELS = {
21 | 23 |     "large": {
22 | 24 |         "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
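The new `DEPTH_ANYTHING_MODEL_SIZES` alias gives the module a single source of truth for the accepted size strings, so static type checkers can flag unsupported values. A minimal standalone sketch of what the alias enables (the `checkpoint_name` helper is hypothetical, for illustration only):

```python
from typing import Literal

DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]

# Hypothetical helper, not part of the diff: mypy/pyright will reject
# any argument that is not one of the three literal members.
def checkpoint_name(model_size: DEPTH_ANYTHING_MODEL_SIZES) -> str:
    return f"depth_anything_{model_size}.pth"

checkpoint_name("small")   # OK
checkpoint_name("medium")  # type checker error: not a valid Literal member
```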
53 | 55 | class DepthAnythingDetector:
54 | 56 |     def __init__(self) -> None:
55 | 57 |         self.model = None
56 |    | -        self.model_size: Union[Literal["large", "base", "small"], None] = None
   | 58 | +        self.model_size: Union[DEPTH_ANYTHING_MODEL_SIZES, None] = None
57 | 59 |
58 |    | -    def load_model(self, model_size=Literal["large", "base", "small"]):
   | 60 | +    def load_model(self, model_size: DEPTH_ANYTHING_MODEL_SIZES = "small"):
59 | 61 |         DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
60 | 62 |         if not DEPTH_ANYTHING_MODEL_PATH.exists():
61 | 63 |             download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)
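Note that the old signature used `=` where `:` was intended, so the `Literal` type object itself was the parameter's default value; the new signature makes `model_size` a properly annotated parameter with a real default. A hedged usage sketch, assuming the module's config and download helper are wired up as above:

```python
# Usage sketch (assumes InvokeAI's config and model paths are available).
detector = DepthAnythingDetector()
detector.load_model("base")  # fetches the checkpoint on first use, then loads it
detector.load_model()        # falls back to the "small" default
```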
@@ -84,16 +86,19 @@ def to(self, device):
84 | 86 |         self.model.to(device)
85 | 87 |         return self
86 | 88 |
87 |    | -    def __call__(self, image, resolution=512):
88 |    | -        image = np.array(image, dtype=np.uint8)
89 |    | -        image = image[:, :, ::-1] / 255.0
   | 89 | +    def __call__(self, image: Image.Image, resolution: int = 512):
   | 90 | +        if self.model is None:
   | 91 | +            raise Exception("Depth Anything Model not loaded")
   | 92 | +
   | 93 | +        np_image = np.array(image, dtype=np.uint8)
   | 94 | +        np_image = np_image[:, :, ::-1] / 255.0
90 | 95 |
91 |    | -        image_height, image_width = image.shape[:2]
92 |    | -        image = transform({"image": image})["image"]
93 |    | -        image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
   | 96 | +        image_height, image_width = np_image.shape[:2]
   | 97 | +        np_image = transform({"image": np_image})["image"]
   | 98 | +        tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
94 | 99  |
95 | 100 |         with torch.no_grad():
96 |     | -            depth = self.model(image)
   | 101 | +            depth = self.model(tensor_image)
97 | 102 |             depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
98 | 103 |             depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
99 | 104 |
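After the renamed tensors flow through the model, the post-processing resizes the raw prediction back to the input dimensions and min-max scales it into 0-255 for an 8-bit depth map. A standalone NumPy sketch of that scaling step (shapes and values are illustrative, not taken from the diff):

```python
import numpy as np

# Illustrative stand-in for the torch post-processing in __call__:
# min-max scale a raw depth prediction into the 0-255 range.
depth = np.random.rand(480, 640).astype(np.float32)  # fake model output
scaled = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
depth_u8 = scaled.astype(np.uint8)  # nearest point maps to 0, farthest to 255
```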