Skip to content

Commit d2119c2

Browse files
shapovalov authored and facebook-github-bot committed
Serialising dynamic arrays in SQL; read-only SQLite connection in SQL Dataset
Summary: 1. We may need to store arrays of unknown shape in the database. This commit implements and tests their serialisation. 2. Previously, when a non-existent metadata file was passed to SqlIndexDataset, it would try to open it and create an empty file, then crash. We now open the file in read-only mode, so the error message is more intuitive. Note that the implementation is SQLite-specific. Reviewed By: bottler Differential Revision: D46047857 fbshipit-source-id: 3064ae4f8122b4fc24ad3d6ab696572ebe8d0c26
1 parent ff80183 commit d2119c2

File tree

3 files changed

+60
-5
lines changed

3 files changed

+60
-5
lines changed

pytorch3d/implicitron/dataset/orm_types.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,35 @@
3333

3434

3535
# these produce policies to serialize structured types to blobs
36-
def ArrayTypeFactory(shape):
36+
def ArrayTypeFactory(shape=None):
37+
if shape is None:
38+
39+
class VariableShapeNumpyArrayType(TypeDecorator):
40+
impl = LargeBinary
41+
42+
def process_bind_param(self, value, dialect):
43+
if value is None:
44+
return None
45+
46+
ndim_bytes = np.int32(value.ndim).tobytes()
47+
shape_bytes = np.array(value.shape, dtype=np.int64).tobytes()
48+
value_bytes = value.astype(np.float32).tobytes()
49+
return ndim_bytes + shape_bytes + value_bytes
50+
51+
def process_result_value(self, value, dialect):
52+
if value is None:
53+
return None
54+
55+
ndim = np.frombuffer(value[:4], dtype=np.int32)[0]
56+
value_start = 4 + 8 * ndim
57+
shape = np.frombuffer(value[4:value_start], dtype=np.int64)
58+
assert shape.shape == (ndim,)
59+
return np.frombuffer(value[value_start:], dtype=np.float32).reshape(
60+
shape
61+
)
62+
63+
return VariableShapeNumpyArrayType
64+
3765
class NumpyArrayType(TypeDecorator):
3866
impl = LargeBinary
3967

@@ -158,4 +186,4 @@ class SqlSequenceAnnotation(Base):
158186
mapped_column("_point_cloud_n_points", nullable=True),
159187
)
160188
# the bigger the better
161-
viewpoint_quality_score: Mapped[Optional[float]] = mapped_column(default=None)
189+
viewpoint_quality_score: Mapped[Optional[float]] = mapped_column()

pytorch3d/implicitron/dataset/sql_dataset.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,10 @@ def __post_init__(self) -> None:
142142
run_auto_creation(self)
143143
self.frame_data_builder.path_manager = self.path_manager
144144

145-
# pyre-ignore
146-
self._sql_engine = sa.create_engine(f"sqlite:///{self.sqlite_metadata_file}")
145+
# pyre-ignore # NOTE: sqlite-specific args (read-only mode).
146+
self._sql_engine = sa.create_engine(
147+
f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
148+
)
147149

148150
sequences = self._get_filtered_sequences_if_any()
149151

tests/implicitron/test_orm_types.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import numpy as np
1010

11-
from pytorch3d.implicitron.dataset.orm_types import TupleTypeFactory
11+
from pytorch3d.implicitron.dataset.orm_types import ArrayTypeFactory, TupleTypeFactory
1212

1313

1414
class TestOrmTypes(unittest.TestCase):
@@ -35,3 +35,28 @@ def test_tuple_serialization_2d(self):
3535
self.assertEqual(type(input_hat[0][0]), type(input_tuple[0][0]))
3636
# we use float32 to serialise
3737
np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)
38+
39+
def test_array_serialization_none(self):
40+
ttype = ArrayTypeFactory((3, 3))()
41+
output = ttype.process_bind_param(None, None)
42+
self.assertIsNone(output)
43+
output = ttype.process_result_value(output, None)
44+
self.assertIsNone(output)
45+
46+
def test_array_serialization(self):
47+
for input_list in [[1, 2, 3], [[4.5, 6.7], [8.9, 10.0]]]:
48+
input_array = np.array(input_list)
49+
50+
# first, dynamic-size array
51+
ttype = ArrayTypeFactory()()
52+
output = ttype.process_bind_param(input_array, None)
53+
input_hat = ttype.process_result_value(output, None)
54+
self.assertEqual(input_hat.dtype, np.float32)
55+
np.testing.assert_almost_equal(input_hat, input_array, decimal=6)
56+
57+
# second, fixed-size array
58+
ttype = ArrayTypeFactory(tuple(input_array.shape))()
59+
output = ttype.process_bind_param(input_array, None)
60+
input_hat = ttype.process_result_value(output, None)
61+
self.assertEqual(input_hat.dtype, np.float32)
62+
np.testing.assert_almost_equal(input_hat, input_array, decimal=6)

0 commit comments

Comments
 (0)