Skip to content

Commit 32e1992

Browse files
shapovalovfacebook-github-bot
authored andcommitted
SQL Index Dataset
Summary: Moving SQL dataset to PyTorch3D. It has been extensively tested in pixar_replay. It requires SQLAlchemy 2.0, which is not supported in fbcode. So I exclude the sources and tests that depend on it from buck TARGETS. Reviewed By: bottler Differential Revision: D45086611 fbshipit-source-id: 0285f03e5824c0478c70ad13731525bb5ec7deef
1 parent 7aeedd1 commit 32e1992

File tree

10 files changed

+2309
-6
lines changed

10 files changed

+2309
-6
lines changed

pytorch3d/implicitron/dataset/frame_data.py

+30-6
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,7 @@ def build(
450450
self,
451451
frame_annotation: types.FrameAnnotation,
452452
sequence_annotation: types.SequenceAnnotation,
453+
load_blobs: bool = True,
453454
) -> FrameDataSubtype:
454455
"""An abstract method to build the frame data based on raw frame/sequence
455456
annotations, load the binary data and adjust them according to the metadata.
@@ -465,8 +466,9 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
465466
Beware that modifications of frame data are done in-place.
466467
467468
Args:
468-
dataset_root: The root folder of the dataset; all the paths in jsons are
469-
specified relative to this root (but not json paths themselves).
469+
dataset_root: The root folder of the dataset; all paths in frame / sequence
470+
annotations are defined w.r.t. this root. Has to be set if any of the
471+
load_* flabs below is true.
470472
load_images: Enable loading the frame RGB data.
471473
load_depths: Enable loading the frame depth maps.
472474
load_depth_masks: Enable loading the frame depth map masks denoting the
@@ -494,7 +496,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
494496
path_manager: Optionally a PathManager for interpreting paths in a special way.
495497
"""
496498

497-
dataset_root: str = ""
499+
dataset_root: Optional[str] = None
498500
load_images: bool = True
499501
load_depths: bool = True
500502
load_depth_masks: bool = True
@@ -510,6 +512,25 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
510512
box_crop_context: float = 0.3
511513
path_manager: Any = None
512514

515+
def __post_init__(self) -> None:
516+
load_any_blob = (
517+
self.load_images
518+
or self.load_depths
519+
or self.load_depth_masks
520+
or self.load_masks
521+
or self.load_point_clouds
522+
)
523+
if load_any_blob and self.dataset_root is None:
524+
raise ValueError(
525+
"dataset_root must be set to load any blob data. "
526+
"Make sure it is set in either FrameDataBuilder or Dataset params."
527+
)
528+
529+
if load_any_blob and not os.path.isdir(self.dataset_root): # pyre-ignore
530+
raise ValueError(
531+
f"dataset_root is passed but {self.dataset_root} does not exist."
532+
)
533+
513534
def build(
514535
self,
515536
frame_annotation: types.FrameAnnotation,
@@ -567,7 +588,7 @@ def build(
567588
if bbox_xywh is None and fg_mask_np is not None:
568589
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
569590

570-
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)
591+
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
571592

572593
if frame_annotation.image is not None:
573594
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
@@ -612,7 +633,8 @@ def build(
612633
def _load_fg_probability(
613634
self, entry: types.FrameAnnotation
614635
) -> Tuple[np.ndarray, str]:
615-
full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore
636+
assert self.dataset_root is not None and entry.mask is not None
637+
full_path = os.path.join(self.dataset_root, entry.mask.path)
616638
fg_probability = load_mask(self._local_path(full_path))
617639
if fg_probability.shape[-2:] != entry.image.size:
618640
raise ValueError(
@@ -647,7 +669,7 @@ def _load_mask_depth(
647669
fg_probability: Optional[torch.Tensor],
648670
) -> Tuple[torch.Tensor, str, torch.Tensor]:
649671
entry_depth = entry.depth
650-
assert entry_depth is not None
672+
assert self.dataset_root is not None and entry_depth is not None
651673
path = os.path.join(self.dataset_root, entry_depth.path)
652674
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
653675

@@ -657,6 +679,7 @@ def _load_mask_depth(
657679

658680
if self.load_depth_masks:
659681
assert entry_depth.mask_path is not None
682+
# pyre-ignore
660683
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
661684
depth_mask = load_depth_mask(self._local_path(mask_path))
662685
else:
@@ -705,6 +728,7 @@ def _fix_point_cloud_path(self, path: str) -> str:
705728
)
706729
if path.startswith(unwanted_prefix):
707730
path = path[len(unwanted_prefix) :]
731+
assert self.dataset_root is not None
708732
return os.path.join(self.dataset_root, path)
709733

710734
def _local_path(self, path: str) -> str:
+161
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# This functionality requires SQLAlchemy 2.0 or later.
8+
9+
import math
10+
import struct
11+
from typing import Optional, Tuple
12+
13+
import numpy as np
14+
15+
from pytorch3d.implicitron.dataset.types import (
16+
DepthAnnotation,
17+
ImageAnnotation,
18+
MaskAnnotation,
19+
PointCloudAnnotation,
20+
VideoAnnotation,
21+
ViewpointAnnotation,
22+
)
23+
24+
from sqlalchemy import LargeBinary
25+
from sqlalchemy.orm import (
26+
composite,
27+
DeclarativeBase,
28+
Mapped,
29+
mapped_column,
30+
MappedAsDataclass,
31+
)
32+
from sqlalchemy.types import TypeDecorator
33+
34+
35+
# these produce policies to serialize structured types to blobs
36+
def ArrayTypeFactory(shape):
37+
class NumpyArrayType(TypeDecorator):
38+
impl = LargeBinary
39+
40+
def process_bind_param(self, value, dialect):
41+
if value is not None:
42+
if value.shape != shape:
43+
raise ValueError(f"Passed an array of wrong shape: {value.shape}")
44+
return value.astype(np.float32).tobytes()
45+
return None
46+
47+
def process_result_value(self, value, dialect):
48+
if value is not None:
49+
return np.frombuffer(value, dtype=np.float32).reshape(shape)
50+
return None
51+
52+
return NumpyArrayType
53+
54+
55+
def TupleTypeFactory(dtype=float, shape: Tuple[int, ...] = (2,)):
56+
format_symbol = {
57+
float: "f", # float32
58+
int: "i", # int32
59+
}[dtype]
60+
61+
class TupleType(TypeDecorator):
62+
impl = LargeBinary
63+
_format = format_symbol * math.prod(shape)
64+
65+
def process_bind_param(self, value, _):
66+
if value is None:
67+
return None
68+
69+
if len(shape) > 1:
70+
value = np.array(value, dtype=dtype).reshape(-1)
71+
72+
return struct.pack(TupleType._format, *value)
73+
74+
def process_result_value(self, value, _):
75+
if value is None:
76+
return None
77+
78+
loaded = struct.unpack(TupleType._format, value)
79+
if len(shape) > 1:
80+
loaded = _rec_totuple(
81+
np.array(loaded, dtype=dtype).reshape(shape).tolist()
82+
)
83+
84+
return loaded
85+
86+
return TupleType
87+
88+
89+
def _rec_totuple(t):
90+
if isinstance(t, list):
91+
return tuple(_rec_totuple(x) for x in t)
92+
93+
return t
94+
95+
96+
class Base(MappedAsDataclass, DeclarativeBase):
97+
"""subclasses will be converted to dataclasses"""
98+
99+
100+
class SqlFrameAnnotation(Base):
101+
__tablename__ = "frame_annots"
102+
103+
sequence_name: Mapped[str] = mapped_column(primary_key=True)
104+
frame_number: Mapped[int] = mapped_column(primary_key=True)
105+
frame_timestamp: Mapped[float] = mapped_column(index=True)
106+
107+
image: Mapped[ImageAnnotation] = composite(
108+
mapped_column("_image_path"),
109+
mapped_column("_image_size", TupleTypeFactory(int)),
110+
)
111+
112+
depth: Mapped[DepthAnnotation] = composite(
113+
mapped_column("_depth_path", nullable=True),
114+
mapped_column("_depth_scale_adjustment", nullable=True),
115+
mapped_column("_depth_mask_path", nullable=True),
116+
)
117+
118+
mask: Mapped[MaskAnnotation] = composite(
119+
mapped_column("_mask_path", nullable=True),
120+
mapped_column("_mask_mass", index=True, nullable=True),
121+
mapped_column(
122+
"_mask_bounding_box_xywh",
123+
TupleTypeFactory(float, shape=(4,)),
124+
nullable=True,
125+
),
126+
)
127+
128+
viewpoint: Mapped[ViewpointAnnotation] = composite(
129+
mapped_column(
130+
"_viewpoint_R", TupleTypeFactory(float, shape=(3, 3)), nullable=True
131+
),
132+
mapped_column(
133+
"_viewpoint_T", TupleTypeFactory(float, shape=(3,)), nullable=True
134+
),
135+
mapped_column(
136+
"_viewpoint_focal_length", TupleTypeFactory(float), nullable=True
137+
),
138+
mapped_column(
139+
"_viewpoint_principal_point", TupleTypeFactory(float), nullable=True
140+
),
141+
mapped_column("_viewpoint_intrinsics_format", nullable=True),
142+
)
143+
144+
145+
class SqlSequenceAnnotation(Base):
146+
__tablename__ = "sequence_annots"
147+
148+
sequence_name: Mapped[str] = mapped_column(primary_key=True)
149+
category: Mapped[str] = mapped_column(index=True)
150+
151+
video: Mapped[VideoAnnotation] = composite(
152+
mapped_column("_video_path", nullable=True),
153+
mapped_column("_video_length", nullable=True),
154+
)
155+
point_cloud: Mapped[PointCloudAnnotation] = composite(
156+
mapped_column("_point_cloud_path", nullable=True),
157+
mapped_column("_point_cloud_quality_score", nullable=True),
158+
mapped_column("_point_cloud_n_points", nullable=True),
159+
)
160+
# the bigger the better
161+
viewpoint_quality_score: Mapped[Optional[float]] = mapped_column(default=None)

0 commit comments

Comments
 (0)