Skip to content

Commit e1525b0

Browse files
authored
Add shape::, array_ref::, and array::bounds() (#105)
* Add .vscode to gitignore * Add shape::, array_ref::, and array::bounds() Helps avoid the common mistake of creating an array from an existing argument's shape, inheriting the strides and possibly violating compile-time constraints. (e.g. creating a planar image from an interleaved image's shape).
1 parent 1d1624b commit e1525b0

File tree

4 files changed

+78
-0
lines changed

4 files changed

+78
-0
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ docs/*
3636

3737
# Visual Studio folder status
3838
.vs
39+
.vscode
3940

4041
# perf files
4142
perf.data

include/array/array.h

+31
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,12 @@ NDARRAY_HOST_DEVICE const DimsSrc& assert_dims_compatible(const DimsSrc& src) {
10301030
return src;
10311031
}
10321032

1033+
/** Return a tuple of generic `dims` with same min and extents and all strides set to `unresolved`. */
1034+
template <class Dims, size_t... Is>
1035+
auto bounds_tuple(const Dims& dims, index_sequence<Is...>) {
1036+
return std::make_tuple(dim<>(std::get<Is>(dims).min(), std::get<Is>(dims).extent(), unresolved)...);
1037+
}
1038+
10331039
} // namespace internal
10341040

10351041
template <class... Dims>
@@ -1293,6 +1299,13 @@ class shape {
12931299
NDARRAY_HOST_DEVICE index_t rows() const { return i().extent(); }
12941300
NDARRAY_HOST_DEVICE index_t columns() const { return j().extent(); }
12951301

1302+
/** Returns a shape with dynamic dims, with the same min and extents but
1303+
* strides initialized to `nda::unresolved`. This can be used to create
1304+
* an array with the same dimensions but different compile-time constraints. */
1305+
NDARRAY_HOST_DEVICE auto bounds() const {
1306+
return make_shape_from_tuple(internal::bounds_tuple(dims_, dim_indices()));
1307+
}
1308+
12961309
/** A shape is equal to another shape if the dim objects of each
12971310
* dimension from both shapes are equal. */
12981311
template <class... OtherDims, class = enable_if_same_rank<OtherDims...>>
@@ -2115,6 +2128,15 @@ class array_ref {
21152128
}
21162129
const nda::dim<> dim(size_t d) const { return shape_.dim(d); }
21172130
NDARRAY_HOST_DEVICE size_type size() const { return shape_.size(); }
2131+
/** Returns a shape with dynamic dims, with the same min and extents but
2132+
* strides initialized to `nda::unresolved`. This can be used to create
2133+
* an array with the same dimensions but different compile-time constraints:
2134+
*
2135+
* nda::array_ref<T, SrcShape> ref(...);
2136+
* nda::array<T, DstShape> y(ref.bounds()); // Compact array with the same `min`
2137+
* // and `extents` as `ref`.
2138+
*/
2139+
NDARRAY_HOST_DEVICE auto bounds() const { return shape_.bounds(); }
21182140
NDARRAY_HOST_DEVICE bool empty() const { return base() != nullptr ? shape_.empty() : true; }
21192141
NDARRAY_HOST_DEVICE bool is_compact() const { return shape_.is_compact(); }
21202142

@@ -2577,6 +2599,15 @@ class array {
25772599
}
25782600
const nda::dim<> dim(size_t d) const { return shape_.dim(d); }
25792601
size_type size() const { return shape_.size(); }
2602+
/** Returns a shape with dynamic dims, with the same min and extents but
2603+
* strides initialized to `nda::unresolved`. This can be used to create
2604+
* an array with the same dimensions but different compile-time constraints:
2605+
*
2606+
* nda::array<T, SrcShape> other(...);
2607+
* nda::array<T, DstShape> y(other.bounds()); // Compact array with the same `min`
2608+
* // and `extents` as `other`.
2609+
*/
2610+
auto bounds() const { return shape_.bounds(); }
25802611
bool empty() const { return shape_.empty(); }
25812612
bool is_compact() const { return shape_.is_compact(); }
25822613

test/array.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,27 @@ TEST(array_default_constructor) {
5959
sparse.clear();
6060
}
6161

62+
TEST(array_construct_from_bounds) {
63+
// Illustrate creating an array from incompatible compile-time
64+
// shapes using bounds().
65+
using SrcShape = shape<dense_dim<>, dim<>>;
66+
using DstShape = shape<dim<>, dense_dim<>>; // dense_dim is swapped.
67+
68+
SrcShape src_shape({-1, 10}, {2, 5, /*stride =*/100});
69+
auto src = array<int, SrcShape>(src_shape);
70+
auto dst = array<int, DstShape>(src.bounds());
71+
72+
// min and extent are preserved.
73+
ASSERT_EQ(src.dim<0>().min(), dst.dim<0>().min());
74+
ASSERT_EQ(src.dim<0>().extent(), dst.dim<0>().extent());
75+
ASSERT_EQ(src.dim<1>().min(), dst.dim<1>().min());
76+
ASSERT_EQ(src.dim<1>().extent(), dst.dim<1>().extent());
77+
// Ensure that `dst` did not inherit the strides of its parent and is
78+
// created compact.
79+
ASSERT(!src.is_compact());
80+
ASSERT(dst.is_compact());
81+
}
82+
6283
TEST(array_static_convertibility) {
6384
using A0 = array_of_rank<int, 0>;
6485
using A3 = array_of_rank<int, 3>;

test/shape.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ TEST(shape_scalar) {
2626
ASSERT_EQ(s.flat_extent(), 1);
2727
ASSERT_EQ(s.size(), 1);
2828
ASSERT_EQ(s(), 0);
29+
ASSERT_EQ(s.bounds().rank(), 0);
2930
}
3031

3132
TEST(shape_1d) {
@@ -205,6 +206,30 @@ TEST(auto_strides) {
205206
test_auto_strides<10>();
206207
}
207208

209+
TEST(bounds) {
210+
dim</* Min= */ 0, /* Extent= */ 10> x;
211+
dense_dim<> y(-2, 12);
212+
auto s = make_shape(x, y);
213+
s.resolve();
214+
215+
auto bounds = s.bounds();
216+
// Returns a generic shape w/o compile-time min, extents, or strides.
217+
ASSERT_EQ(bounds.dim<0>().Min, dynamic);
218+
ASSERT_EQ(bounds.dim<0>().Extent, dynamic);
219+
ASSERT_EQ(bounds.dim<0>().Stride, dynamic);
220+
ASSERT_EQ(bounds.dim<1>().Min, dynamic);
221+
ASSERT_EQ(bounds.dim<1>().Extent, dynamic);
222+
ASSERT_EQ(bounds.dim<1>().Stride, dynamic);
223+
// Check that dynamic min and extents are preserved.
224+
ASSERT_EQ(bounds.dim<0>().min(), x.min());
225+
ASSERT_EQ(bounds.dim<0>().extent(), x.extent());
226+
ASSERT_EQ(bounds.dim<1>().min(), y.min());
227+
ASSERT_EQ(bounds.dim<1>().extent(), y.extent());
228+
// Bounds have strides set to unresolved.
229+
ASSERT_EQ(bounds.dim<0>().stride(), nda::unresolved);
230+
ASSERT_EQ(bounds.dim<1>().stride(), nda::unresolved);
231+
}
232+
208233
TEST(broadcast_dim) {
209234
dim<> x(0, 10, 1);
210235
broadcast_dim<> y;

0 commit comments

Comments
 (0)