Skip to content

Commit 61b9165

Browse files
authored
Indexing API (#166)
* Start of Indexing api Signed-off-by: Ryan Nett <[email protected]> * Add javadoc Signed-off-by: Ryan Nett <[email protected]> * codegen Signed-off-by: Ryan Nett <[email protected]> * op test Signed-off-by: Ryan Nett <[email protected]> * fix test Signed-off-by: Ryan Nett <[email protected]> * explain expected shape, fix slice input Signed-off-by: Ryan Nett <[email protected]> * add a final Signed-off-by: Ryan Nett <[email protected]> * fix constructor visibility Signed-off-by: Ryan Nett <[email protected]> * fix range check Signed-off-by: Ryan Nett <[email protected]> * start of adapting ndarray indexing Signed-off-by: Ryan Nett <[email protected]> * remove old Index class Signed-off-by: Ryan Nett <[email protected]> * test fix for rebase Signed-off-by: Ryan Nett <[email protected]> * test fix for rebase Signed-off-by: Ryan Nett <[email protected]> * Out of bounds warnings Signed-off-by: Ryan Nett <[email protected]> * Remove extra classes, plus a few fixes Signed-off-by: Ryan Nett <[email protected]> * Tests Signed-off-by: Ryan Nett <[email protected]> * ToString methods Signed-off-by: Ryan Nett <[email protected]> * Cleanup and formatting Signed-off-by: Ryan Nett <[email protected]> * Cleanup and formatting Signed-off-by: Ryan Nett <[email protected]> * Javadocs cleanup, new names Signed-off-by: Ryan Nett <[email protected]> * Split Slice into nullability cases Signed-off-by: Ryan Nett <[email protected]> * Change benchmark to fork once by default Signed-off-by: Ryan Nett <[email protected]> * Remove expand Signed-off-by: Ryan Nett <[email protected]> * remove tensor references Signed-off-by: Ryan Nett <[email protected]>
1 parent 4802fd2 commit 61b9165

File tree

24 files changed

+1299
-225
lines changed

24 files changed

+1299
-225
lines changed

ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
package org.tensorflow.ndarray.impl.dimension;
1919

2020
import java.util.Arrays;
21-
import org.tensorflow.ndarray.index.Index;
2221
import org.tensorflow.ndarray.Shape;
22+
import org.tensorflow.ndarray.index.Index;
2323

2424
public class DimensionalSpace {
2525

@@ -35,24 +35,42 @@ public static DimensionalSpace create(Shape shape) {
3535
}
3636

3737
public RelativeDimensionalSpace mapTo(Index[] indices) {
38-
if (dimensions == null || indices.length > dimensions.length) {
38+
if (dimensions == null) {
3939
throw new ArrayIndexOutOfBoundsException();
4040
}
4141
int dimIdx = 0;
42+
int indexIdx = 0;
4243
int newDimIdx = 0;
4344
int segmentationIdx = -1;
4445
long initialOffset = 0;
4546

46-
Dimension[] newDimensions = new Dimension[dimensions.length];
47-
while (dimIdx < indices.length) {
47+
int newAxes = 0;
48+
boolean seenEllipsis = false;
49+
for (Index idx : indices) {
50+
if (idx.isNewAxis()) {
51+
newAxes += 1;
52+
}
53+
if (idx.isEllipsis()) {
54+
if (seenEllipsis) {
55+
throw new IllegalArgumentException("Only one ellipsis allowed");
56+
} else {
57+
seenEllipsis = true;
58+
}
59+
}
60+
}
61+
int newLength = dimensions.length + newAxes;
62+
63+
Dimension[] newDimensions = new Dimension[newLength];
64+
while (indexIdx < indices.length) {
4865

49-
if (indices[dimIdx].isPoint()) {
66+
if (indices[indexIdx].isPoint()) {
5067
// When an index targets a single point in a given dimension, calculate the offset of this
5168
// point and cumulate the offset of any subsequent point as well
5269
long offset = 0;
5370
do {
54-
offset += indices[dimIdx].mapCoordinate(0, dimensions[dimIdx]);
55-
} while (++dimIdx < indices.length && indices[dimIdx].isPoint());
71+
offset += indices[indexIdx].mapCoordinate(0, dimensions[dimIdx]);
72+
dimIdx++;
73+
} while (++indexIdx < indices.length && indices[indexIdx].isPoint());
5674

5775
// If this is the first index, then the offset is the position of the whole dimension
5876
// space within the original one. If not, then we apply the offset to the last vectorial
@@ -65,14 +83,47 @@ public RelativeDimensionalSpace mapTo(Index[] indices) {
6583
segmentationIdx = newDimIdx - 1;
6684
}
6785

86+
} else if (indices[indexIdx].isNewAxis()) {
87+
long newSize;
88+
if (dimIdx == 0) {
89+
// includes everything. Should really include future reduction (at()) but that doesn't seem to cause issues
90+
// elsewhere
91+
newSize = dimensions[0].numElements() * dimensions[0].elementSize();
92+
} else {
93+
newSize = dimensions[dimIdx - 1].elementSize();
94+
}
95+
96+
newDimensions[newDimIdx] = new Axis(1, newSize);
97+
segmentationIdx = newDimIdx; // is this correct?
98+
++newDimIdx;
99+
++indexIdx;
100+
} else if (indices[indexIdx].isEllipsis()) {
101+
int remainingDimensions = dimensions.length - dimIdx;
102+
int requiredDimensions = 0;
103+
for (int i = indexIdx + 1; i < indices.length; i++) {
104+
if (!indices[i].isNewAxis()) {
105+
requiredDimensions++;
106+
}
107+
}
108+
// while the number of dimensions left < the number of indices that consume axes
109+
while (remainingDimensions > requiredDimensions) {
110+
Dimension dim = dimensions[dimIdx++];
111+
if (dim.isSegmented()) {
112+
segmentationIdx = newDimIdx;
113+
}
114+
newDimensions[newDimIdx++] = dim;
115+
remainingDimensions--;
116+
}
117+
indexIdx++;
68118
} else {
69119
// Map any other index to the appropriate dimension of this space
70-
Dimension newDimension = indices[dimIdx].apply(dimensions[dimIdx++]);
120+
Dimension newDimension = indices[indexIdx].apply(dimensions[dimIdx++]);
71121
newDimensions[newDimIdx] = newDimension;
72122
if (newDimension.isSegmented()) {
73123
segmentationIdx = newDimIdx;
74124
}
75125
++newDimIdx;
126+
++indexIdx;
76127
}
77128
}
78129

ndarray/src/main/java/org/tensorflow/ndarray/index/All.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,19 @@ public Dimension apply(Dimension dim) {
3939

4040
private All() {
4141
}
42+
43+
@Override
44+
public boolean beginMask() {
45+
return true;
46+
}
47+
48+
@Override
49+
public boolean endMask() {
50+
return true;
51+
}
52+
53+
@Override
54+
public String toString() {
55+
return All.class.getSimpleName() + "()";
56+
}
4257
}

ndarray/src/main/java/org/tensorflow/ndarray/index/At.java

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
*/
1717
package org.tensorflow.ndarray.index;
1818

19+
import java.util.StringJoiner;
1920
import org.tensorflow.ndarray.impl.dimension.Dimension;
2021

2122
final class At implements Index {
@@ -27,22 +28,47 @@ public long numElements(Dimension dim) {
2728

2829
@Override
2930
public long mapCoordinate(long coordinate, Dimension dim) {
30-
return dim.positionOf(coord); // TODO validate coordinate is 0?
31+
long coord = this.coord >= 0 ? this.coord : dim.numElements() + this.coord;
32+
return dim.positionOf(coord);
3133
}
3234

3335
@Override
3436
public Dimension apply(Dimension dim) {
35-
throw new IllegalStateException(); // FIXME?
37+
if (!keepDim) {
38+
throw new UnsupportedOperationException("Should be handled in DimensionalSpace.");
39+
}
40+
41+
return dim.withIndex(this);
3642
}
3743

3844
@Override
3945
public boolean isPoint() {
40-
return true;
46+
return !keepDim;
4147
}
4248

43-
At(long coord) {
49+
At(long coord, boolean keepDim) {
4450
this.coord = coord;
51+
this.keepDim = keepDim;
4552
}
4653

4754
private final long coord;
55+
private final boolean keepDim;
56+
57+
@Override
58+
public long begin() {
59+
return coord;
60+
}
61+
62+
@Override
63+
public long end() {
64+
return coord + 1;
65+
}
66+
67+
@Override
68+
public String toString() {
69+
return new StringJoiner(", ", At.class.getSimpleName() + "(", ")")
70+
.add("coord=" + coord)
71+
.add("keepDim=" + keepDim)
72+
.toString();
73+
}
4874
}
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
Copyright 2020 The TensorFlow Authors. All Rights Reserved.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -12,26 +12,37 @@
1212
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
15-
=======================================================================
15+
==============================================================================
1616
*/
1717
package org.tensorflow.ndarray.index;
1818

1919
import org.tensorflow.ndarray.impl.dimension.Dimension;
2020

21-
final class Even implements Index {
21+
final class Ellipsis implements Index {
2222

23-
static final Even INSTANCE = new Even();
23+
static final Ellipsis INSTANCE = new Ellipsis();
24+
25+
private Ellipsis() {
26+
27+
}
2428

2529
@Override
2630
public long numElements(Dimension dim) {
27-
return (dim.numElements() >> 1) + (dim.numElements() % 2);
31+
throw new UnsupportedOperationException("Should be handled in DimensionalSpace.");
2832
}
2933

3034
@Override
3135
public long mapCoordinate(long coordinate, Dimension dim) {
32-
return coordinate << 1;
36+
throw new UnsupportedOperationException("Should be handled in DimensionalSpace.");
3337
}
3438

35-
private Even() {
39+
@Override
40+
public boolean isEllipsis() {
41+
return true;
42+
}
43+
44+
@Override
45+
public String toString() {
46+
return Ellipsis.class.getSimpleName() + "()";
3647
}
3748
}

ndarray/src/main/java/org/tensorflow/ndarray/index/Flip.java

Lines changed: 0 additions & 34 deletions
This file was deleted.

ndarray/src/main/java/org/tensorflow/ndarray/index/From.java

Lines changed: 0 additions & 38 deletions
This file was deleted.

ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package org.tensorflow.ndarray.index;
1717

18+
import java.util.StringJoiner;
1819
import org.tensorflow.ndarray.impl.dimension.Dimension;
1920

2021
/**
@@ -71,4 +72,19 @@ public boolean isPoint() {
7172
private final long stride;
7273
private final long count;
7374
private final long block;
75+
76+
@Override
77+
public String toString() {
78+
return new StringJoiner(", ", Hyperslab.class.getSimpleName() + "Hyperslab(", ")")
79+
.add("start=" + start)
80+
.add("stride=" + stride)
81+
.add("count=" + count)
82+
.add("block=" + block)
83+
.toString();
84+
}
85+
86+
@Override
87+
public boolean isStridedSlicingCompliant() {
88+
return false;
89+
}
7490
}

0 commit comments

Comments
 (0)