Skip to content

Commit 5bde5a3

Browse files
committed
start of adapting ndarray indexing
Signed-off-by: Ryan Nett <[email protected]>
1 parent f228261 commit 5bde5a3

File tree

9 files changed

+426
-15
lines changed

9 files changed

+426
-15
lines changed

ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,38 @@ public RelativeDimensionalSpace mapTo(Index[] indices) {
3939
throw new ArrayIndexOutOfBoundsException();
4040
}
4141
int dimIdx = 0;
42+
int indexIdx = 0;
4243
int newDimIdx = 0;
4344
int segmentationIdx = -1;
4445
long initialOffset = 0;
4546

46-
Dimension[] newDimensions = new Dimension[dimensions.length];
47-
while (dimIdx < indices.length) {
47+
int newAxes = 0;
48+
boolean seenEllipsis = false;
49+
for(Index idx : indices){
50+
if(idx.isNewAxis()){
51+
newAxes += 1;
52+
}
53+
if(idx.isEllipsis()){
54+
if(seenEllipsis){
55+
throw new IllegalArgumentException("Only one ellipsis allowed");
56+
} else {
57+
seenEllipsis = true;
58+
}
59+
}
60+
}
61+
int newLength = dimensions.length + newAxes;
62+
63+
Dimension[] newDimensions = new Dimension[newLength];
64+
while (indexIdx < indices.length) {
4865

4966
if (indices[dimIdx].isPoint()) {
5067
// When an index targets a single point in a given dimension, calculate the offset of this
5168
// point and cumulate the offset of any subsequent point as well
5269
long offset = 0;
5370
do {
54-
offset += indices[dimIdx].mapCoordinate(0, dimensions[dimIdx]);
55-
} while (++dimIdx < indices.length && indices[dimIdx].isPoint());
71+
offset += indices[indexIdx].mapCoordinate(0, dimensions[dimIdx]);
72+
dimIdx++;
73+
} while (++indexIdx < indices.length && indices[indexIdx].isPoint());
5674

5775
// If this is the first index, then the offset is the position of the whole dimension
5876
// space within the original one. If not, then we apply the offset to the last vectorial
@@ -65,14 +83,47 @@ public RelativeDimensionalSpace mapTo(Index[] indices) {
6583
segmentationIdx = newDimIdx - 1;
6684
}
6785

86+
} else if(indices[indexIdx].isNewAxis()) {
87+
long newSize;
88+
if(dimIdx == 0){
89+
// includes everything. Should really include future reduction (at()) but that doesn't seem to cause issues
90+
// elsewhere
91+
newSize = dimensions[0].numElements() * dimensions[0].elementSize();
92+
} else {
93+
newSize = dimensions[dimIdx - 1].elementSize();
94+
}
95+
96+
newDimensions[newDimIdx] = new Axis(1, newSize);
97+
segmentationIdx = newDimIdx; // is this correct?
98+
++newDimIdx;
99+
++indexIdx;
100+
} else if(indices[indexIdx].isEllipsis()){
101+
int remainingDimensions = dimensions.length - dimIdx;
102+
int requiredDimensions = 0;
103+
for(int i = indexIdx + 1 ; i < indices.length ; i++){
104+
if(!indices[i].isNewAxis()){
105+
requiredDimensions++;
106+
}
107+
}
108+
// while the number of dimensions left < the number of indices that consume axes
109+
while(remainingDimensions > requiredDimensions){
110+
Dimension dim = dimensions[dimIdx++];
111+
if (dim.isSegmented()) {
112+
segmentationIdx = newDimIdx;
113+
}
114+
newDimensions[newDimIdx++] = dim;
115+
remainingDimensions--;
116+
}
117+
indexIdx++;
68118
} else {
69119
// Map any other index to the appropriate dimension of this space
70-
Dimension newDimension = indices[dimIdx].apply(dimensions[dimIdx++]);
120+
Dimension newDimension = indices[indexIdx].apply(dimensions[dimIdx++]);
71121
newDimensions[newDimIdx] = newDimension;
72122
if (newDimension.isSegmented()) {
73123
segmentationIdx = newDimIdx;
74124
}
75125
++newDimIdx;
126+
++indexIdx;
76127
}
77128
}
78129

ndarray/src/main/java/org/tensorflow/ndarray/index/All.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import org.tensorflow.ndarray.impl.dimension.Dimension;
2020

21-
final class All implements Index {
21+
final class All implements TensorIndex {
2222

2323
static final All INSTANCE = new All();
2424

@@ -39,4 +39,14 @@ public Dimension apply(Dimension dim) {
3939

4040
private All() {
4141
}
42+
43+
@Override
44+
public boolean beginMask() {
45+
return true;
46+
}
47+
48+
@Override
49+
public boolean endMask() {
50+
return true;
51+
}
4252
}

ndarray/src/main/java/org/tensorflow/ndarray/index/At.java

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import org.tensorflow.ndarray.impl.dimension.Dimension;
2020

21-
final class At implements Index {
21+
final class At implements TensorIndex {
2222

2323
@Override
2424
public long numElements(Dimension dim) {
@@ -27,22 +27,45 @@ public long numElements(Dimension dim) {
2727

2828
@Override
2929
public long mapCoordinate(long coordinate, Dimension dim) {
30+
long coord = this.coord > 0 ? this.coord : dim.numElements() - this.coord;
3031
return dim.positionOf(coord); // TODO validate coordinate is 0?
3132
}
3233

3334
@Override
3435
public Dimension apply(Dimension dim) {
35-
throw new IllegalStateException(); // FIXME?
36+
if(keepDim){
37+
return dim.withIndex(this);
38+
}
39+
else {
40+
throw new IllegalStateException(); // FIXME?
41+
}
3642
}
3743

3844
@Override
3945
public boolean isPoint() {
40-
return true;
46+
return !keepDim;
4147
}
4248

43-
At(long coord) {
49+
At(long coord, boolean keepDim) {
4450
this.coord = coord;
51+
this.keepDim = keepDim;
4552
}
4653

4754
private final long coord;
55+
private final boolean keepDim;
56+
57+
@Override
58+
public long begin() {
59+
return coord;
60+
}
61+
62+
@Override
63+
public long end() {
64+
return coord + 1;
65+
}
66+
67+
@Override
68+
public boolean shrinkAxisMask() {
69+
return !keepDim;
70+
}
4871
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
17+
package org.tensorflow.ndarray.index;
18+
19+
import org.tensorflow.ndarray.impl.dimension.Dimension;
20+
21+
final class Ellipsis implements TensorIndex{
22+
23+
static final Ellipsis INSTANCE = new Ellipsis();
24+
25+
private Ellipsis(){
26+
27+
}
28+
29+
@Override
30+
public long numElements(Dimension dim) {
31+
throw new IllegalStateException();
32+
}
33+
34+
@Override
35+
public long mapCoordinate(long coordinate, Dimension dim) {
36+
throw new IllegalStateException();
37+
}
38+
39+
@Override
40+
public boolean isEllipsis() {
41+
return true;
42+
}
43+
44+
@Override
45+
public boolean ellipsisMask() {
46+
return true;
47+
}
48+
}

ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,19 @@ default Dimension apply(Dimension dim) {
7474
default boolean isPoint() {
7575
return false;
7676
}
77+
78+
/**
79+
* Returns true if this index is a new axis, adding a dimension of size 1
80+
*/
81+
default boolean isNewAxis() {
82+
return false;
83+
}
84+
85+
/**
86+
* Returns true if this index is an ellipsis, expanding to take as many dimensions as possible
87+
* (and applying all() to them)
88+
*/
89+
default boolean isEllipsis() {
90+
return false;
91+
}
7792
}

ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ public final class Indices {
4040
* @param coord coordinate of the element on the indexed axis
4141
* @return index
4242
*/
43-
public static Index at(long coord) {
44-
return new At(coord);
43+
public static TensorIndex at(long coord) {
44+
return new At(coord, false);
4545
}
4646

4747
/**
@@ -54,11 +54,50 @@ public static Index at(long coord) {
5454
* @return index
5555
* @throws IllegalRankException if {@code coord} is not a scalar (rank 0)
5656
*/
57-
public static Index at(NdArray<? extends Number> coord) {
57+
public static TensorIndex at(NdArray<? extends Number> coord) {
5858
if (coord.rank() > 0) {
5959
throw new IllegalRankException("Only scalars are accepted as a value index");
6060
}
61-
return new At(coord.getObject().longValue());
61+
return new At(coord.getObject().longValue(), false);
62+
}
63+
64+
/**
65+
* A coordinate that selects a specific element on a given dimension.
66+
*
67+
* <p>When this index is applied to a given dimension, the dimension is resolved as a
68+
* single element and therefore, if {@code keepDim} is false, is excluded from the computation of the rank.
69+
* If {@code} keepDim is true, the dimension is collapsed down to one element.
70+
*
71+
* <p>For example, given a 3D matrix on the axis [x, y, z], if
72+
* {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its
73+
* number of elements is {@code x.numElements()}
74+
*
75+
* @param coord coordinate of the element on the indexed axis
76+
* @param keepDim whether to remove the dimension.
77+
* @return index
78+
*/
79+
public static TensorIndex at(long coord, boolean keepDim) {
80+
return new At(coord, keepDim);
81+
}
82+
83+
/**
84+
* A coordinate that selects a specific element on a given dimension.
85+
*
86+
* <p>This is equivalent to call {@link #at(long, boolean)} but where the value of the coordinate is
87+
* provided by an N-dimensional array.
88+
* <p>
89+
* If {@code} keepDim is true, the dimension is collapsed down to one element instead of being removed.
90+
*
91+
* @param coord scalar indicating the coordinate of the element on the indexed axis
92+
* @return index
93+
* @param keepDim whether to remove the dimension.
94+
* @throws IllegalRankException if {@code coord} is not a scalar (rank 0)
95+
*/
96+
public static TensorIndex at(NdArray<? extends Number> coord, boolean keepDim) {
97+
if (coord.rank() > 0) {
98+
throw new IllegalRankException("Only scalars are accepted as a value index");
99+
}
100+
return new At(coord.getObject().longValue(), keepDim);
62101
}
63102

64103
/**
@@ -72,7 +111,7 @@ public static Index at(NdArray<? extends Number> coord) {
72111
*
73112
* @return index
74113
*/
75-
public static Index all() {
114+
public static TensorIndex all() {
76115
return All.INSTANCE;
77116
}
78117

@@ -216,4 +255,22 @@ public static Index flip() {
216255
public static Index hyperslab(long start, long stride, long count, long block) {
217256
return new Hyperslab(start, stride, count, block);
218257
}
258+
259+
//TODO comments, tests, remove extra classes in favor of helper methods
260+
261+
public static TensorIndex newAxis(){
262+
return NewAxis.INSTANCE;
263+
}
264+
265+
public static TensorIndex ellipsis(){
266+
return Ellipsis.INSTANCE;
267+
}
268+
269+
public static TensorIndex expand(){
270+
return ellipsis();
271+
}
272+
273+
public static TensorIndex slice(Long start, Long end, long stride){
274+
return new Slice(start, end, stride);
275+
}
219276
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
17+
package org.tensorflow.ndarray.index;
18+
19+
import org.tensorflow.ndarray.impl.dimension.Dimension;
20+
21+
final class NewAxis implements TensorIndex {
22+
23+
static final NewAxis INSTANCE = new NewAxis();
24+
25+
private NewAxis(){
26+
27+
}
28+
29+
@Override
30+
public long numElements(Dimension dim) {
31+
return 1;
32+
}
33+
34+
@Override
35+
public long mapCoordinate(long coordinate, Dimension dim) {
36+
return coordinate;
37+
}
38+
39+
@Override
40+
public Dimension apply(Dimension dim) {
41+
throw new IllegalStateException();
42+
}
43+
44+
@Override
45+
public boolean isNewAxis() {
46+
return true;
47+
}
48+
49+
@Override
50+
public boolean newAxisMask() {
51+
return true;
52+
}
53+
}

0 commit comments

Comments
 (0)